Skip to content

Commit

Permalink
[CALCITE-6313] Add POWER function for PostgreSQL
Browse files Browse the repository at this point in the history
* The existing power function is moved to all non PostgreSQL libraries
* The new power function is only for PostgreSQL
* The new function returns a decimal if any argument is a decimal
  • Loading branch information
normanj-bitquill committed Apr 26, 2024
1 parent 1566663 commit db336c2
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
import static org.apache.calcite.sql.fun.SqlLibraryOperators.PARSE_TIMESTAMP;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.PARSE_URL;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.POW;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.POWER_PG;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.RANDOM;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP_CONTAINS;
Expand Down Expand Up @@ -644,6 +645,7 @@ Builder populate() {
defineMethod(MOD, BuiltInMethod.MOD.method, NullPolicy.STRICT);
defineMethod(EXP, BuiltInMethod.EXP.method, NullPolicy.STRICT);
defineMethod(POWER, BuiltInMethod.POWER.method, NullPolicy.STRICT);
defineMethod(POWER_PG, BuiltInMethod.POWER_PG.method, NullPolicy.STRICT);
defineMethod(ABS, BuiltInMethod.ABS.method, NullPolicy.STRICT);
defineMethod(LOG2, BuiltInMethod.LOG2.method, NullPolicy.STRICT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2216,7 +2216,22 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding

@LibraryOperator(libraries = {BIG_QUERY, SPARK})
public static final SqlFunction POW =
SqlStdOperatorTable.POWER.withName("POW");
SqlBasicFunction.create("POW",
ReturnTypes.DOUBLE_NULLABLE,
OperandTypes.NUMERIC_NUMERIC,
SqlFunctionCategory.NUMERIC);

/** The {@code POWER(numeric, numeric)} function.
*
* <p>The return type is {@code DECIMAL} if either argument is a
* {@code DECIMAL}. In all other cases, the return type is a double.
*/
@LibraryOperator(libraries = { POSTGRESQL })
public static final SqlFunction POWER_PG =
SqlBasicFunction.create("POWER",
ReturnTypes.DECIMAL_OR_DOUBLE_NULLABLE,
OperandTypes.NUMERIC_NUMERIC,
SqlFunctionCategory.NUMERIC);

/** The "TRUNC(numeric1 [, integer2])" function. Identical to the standard <code>TRUNCATE</code>
* function except the return type should be a double if numeric1 is an integer. */
Expand Down
26 changes: 26 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,32 @@ public static SqlCall stripSeparator(SqlCall call) {
return null;
};

/**
* Type-inference strategy that returns DECIMAL if any of the arguments are DECIMAL. It
* will return DOUBLE in all other cases.
*/
public static final SqlReturnTypeInference DECIMAL_OR_DOUBLE = opBinding -> {
boolean haveDecimal = false;
for (int i = 0; i < opBinding.getOperandCount(); i++) {
if (SqlTypeUtil.isDecimal(opBinding.getOperandType(i))) {
haveDecimal = true;
break;
}
}

if (haveDecimal) {
return opBinding.getTypeFactory().createSqlType(
SqlTypeName.DECIMAL,
17);
} else {
return RelDataTypeImpl.proto(SqlTypeName.DOUBLE, false)
.apply(opBinding.getTypeFactory());
}
};

public static final SqlReturnTypeInference DECIMAL_OR_DOUBLE_NULLABLE =
DECIMAL_OR_DOUBLE.andThen(SqlTypeTransforms.TO_NULLABLE);

/**
* Type-inference strategy whereby the result type of a call is
* {@link #DECIMAL_SCALE0} with a fallback to {@link #ARG0} This rule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ public enum BuiltInMethod {
EXP(SqlFunctions.class, "exp", double.class),
MOD(SqlFunctions.class, "mod", long.class, long.class),
POWER(SqlFunctions.class, "power", double.class, double.class),
POWER_PG(SqlFunctions.class, "power", BigDecimal.class, BigDecimal.class),
REPEAT(SqlFunctions.class, "repeat", String.class, int.class),
SPACE(SqlFunctions.class, "space", int.class),
SPLIT(SqlFunctions.class, "split", String.class),
Expand Down
1 change: 1 addition & 0 deletions site/_docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -2820,6 +2820,7 @@ In the following:
| b | PARSE_TIMESTAMP(format, string[, timeZone]) | Uses format specified by *format* to convert *string* representation of timestamp to a TIMESTAMP WITH LOCAL TIME ZONE value in *timeZone*
| h s | PARSE_URL(urlString, partToExtract [, keyToExtract] ) | Returns the specified *partToExtract* from the *urlString*. Valid values for *partToExtract* include HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, and USERINFO. *keyToExtract* specifies which query to extract
| b s | POW(numeric1, numeric2) | Returns *numeric1* raised to the power *numeric2*
| b c h q m o f s p | POWER(numeric1, numeric2) | Returns *numeric1* raised to the power of *numeric2*
| p | RANDOM() | Generates a random double between 0 and 1 inclusive
| s | REGEXP(string, regexp) | Equivalent to `string1 RLIKE string2`
| b | REGEXP_CONTAINS(string, regexp) | Returns whether *string* is a partial match for the *regexp*
Expand Down
30 changes: 30 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.calcite.sql.dialect.AnsiSqlDialect;
import org.apache.calcite.sql.fun.LibraryOperator;
import org.apache.calcite.sql.fun.SqlLibrary;
import org.apache.calcite.sql.fun.SqlLibraryOperatorTableFactory;
import org.apache.calcite.sql.fun.SqlLibraryOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
Expand All @@ -63,6 +64,7 @@
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.util.SqlOperatorTables;
import org.apache.calcite.sql.util.SqlString;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.sql.validate.SqlNameMatchers;
Expand Down Expand Up @@ -6483,6 +6485,9 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
final SqlOperatorFixture f = fixture();
f.setFor(SqlStdOperatorTable.POWER, VmName.EXPAND);
f.checkScalarApprox("power(2,-2)", "DOUBLE NOT NULL", isExactly("0.25"));
f.checkScalarApprox("power(cast(2 as decimal), cast(-2 as decimal))",
"DOUBLE NOT NULL",
isExactly("0.25"));
f.checkNull("power(cast(null as integer),2)");
f.checkNull("power(2,cast(null as double))");

Expand All @@ -6492,6 +6497,31 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
false);
}

@Test void testPowerDecimalFunc() {
final SqlOperatorFixture f = fixture()
.withOperatorTable(
SqlOperatorTables.chain(
SqlLibraryOperatorTableFactory.INSTANCE.getOperatorTable(
ImmutableList.of(SqlLibrary.POSTGRESQL, SqlLibrary.ALL),
false),
SqlStdOperatorTable.instance()))
.setFor(SqlLibraryOperators.POWER_PG);
f.checkScalarApprox("power(cast(2 as decimal), cast(-2 as decimal))",
"DECIMAL(17, 0) NOT NULL",
isExactly("0.25"));
f.checkScalarApprox("power(cast(2 as decimal), -2)",
"DECIMAL(17, 0) NOT NULL",
isExactly("0.25"));
f.checkScalarApprox("power(2, cast(-2 as decimal))",
"DECIMAL(17, 0) NOT NULL",
isExactly("0.25"));
f.checkScalarApprox("power(2, -2)", "DOUBLE NOT NULL", isExactly("0.25"));
f.checkScalarApprox("power(CAST(0.25 AS DOUBLE), CAST(0.5 AS DOUBLE))",
"DOUBLE NOT NULL", isExactly("0.5"));
f.checkNull("power(null, -2)");
f.checkNull("power(2, null)");
}

@Test void testSqrtFunc() {
final SqlOperatorFixture f = fixture();
f.setFor(SqlStdOperatorTable.SQRT, VmName.EXPAND);
Expand Down

0 comments on commit db336c2

Please sign in to comment.