From db336c227d5dfa8a277a28b6e1cea75ebf897a24 Mon Sep 17 00:00:00 2001 From: Norman Jordan Date: Mon, 8 Apr 2024 14:57:05 -0700 Subject: [PATCH] [CALCITE-6313] Add POWER function for PostgreSQL * The existing power function is moved to all non PostgreSQL libraries * The new power function is only for PostgreSQL * The new function returns a decimal if any argument is a decimal --- .../adapter/enumerable/RexImpTable.java | 2 ++ .../calcite/sql/fun/SqlLibraryOperators.java | 17 ++++++++++- .../apache/calcite/sql/type/ReturnTypes.java | 26 ++++++++++++++++ .../apache/calcite/util/BuiltInMethod.java | 1 + site/_docs/reference.md | 1 + .../apache/calcite/test/SqlOperatorTest.java | 30 +++++++++++++++++++ 6 files changed, 76 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index 2ff47db21940..dea65b951b94 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -232,6 +232,7 @@ import static org.apache.calcite.sql.fun.SqlLibraryOperators.PARSE_TIMESTAMP; import static org.apache.calcite.sql.fun.SqlLibraryOperators.PARSE_URL; import static org.apache.calcite.sql.fun.SqlLibraryOperators.POW; +import static org.apache.calcite.sql.fun.SqlLibraryOperators.POWER_PG; import static org.apache.calcite.sql.fun.SqlLibraryOperators.RANDOM; import static org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP; import static org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP_CONTAINS; @@ -644,6 +645,7 @@ Builder populate() { defineMethod(MOD, BuiltInMethod.MOD.method, NullPolicy.STRICT); defineMethod(EXP, BuiltInMethod.EXP.method, NullPolicy.STRICT); defineMethod(POWER, BuiltInMethod.POWER.method, NullPolicy.STRICT); + defineMethod(POWER_PG, BuiltInMethod.POWER_PG.method, NullPolicy.STRICT); defineMethod(ABS, BuiltInMethod.ABS.method, NullPolicy.STRICT); defineMethod(LOG2, BuiltInMethod.LOG2.method, NullPolicy.STRICT); diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index 6996ecb3160a..fe94fecd7caf 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -2216,7 +2216,22 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding @LibraryOperator(libraries = {BIG_QUERY, SPARK}) public static final SqlFunction POW = - SqlStdOperatorTable.POWER.withName("POW"); + SqlBasicFunction.create("POW", + ReturnTypes.DOUBLE_NULLABLE, + OperandTypes.NUMERIC_NUMERIC, + SqlFunctionCategory.NUMERIC); + + /** The {@code POWER(numeric, numeric)} function. + * + *

The return type is {@code DECIMAL} if either argument is a + * {@code DECIMAL}. In all other cases, the return type is a double. + */ + @LibraryOperator(libraries = { POSTGRESQL }) + public static final SqlFunction POWER_PG = + SqlBasicFunction.create("POWER", + ReturnTypes.DECIMAL_OR_DOUBLE_NULLABLE, + OperandTypes.NUMERIC_NUMERIC, + SqlFunctionCategory.NUMERIC); /** The "TRUNC(numeric1 [, integer2])" function. Identical to the standard TRUNCATE * function except the return type should be a double if numeric1 is an integer. */ diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java index 386ed96fb915..eb7f9e447179 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java @@ -773,6 +773,32 @@ public static SqlCall stripSeparator(SqlCall call) { return null; }; + /** + * Type-inference strategy that returns DECIMAL if any of the arguments are DECIMAL. It + * will return DOUBLE in all other cases. + */ + public static final SqlReturnTypeInference DECIMAL_OR_DOUBLE = opBinding -> { + boolean haveDecimal = false; + for (int i = 0; i < opBinding.getOperandCount(); i++) { + if (SqlTypeUtil.isDecimal(opBinding.getOperandType(i))) { + haveDecimal = true; + break; + } + } + + if (haveDecimal) { + return opBinding.getTypeFactory().createSqlType( + SqlTypeName.DECIMAL, + 17); + } else { + return RelDataTypeImpl.proto(SqlTypeName.DOUBLE, false) + .apply(opBinding.getTypeFactory()); + } + }; + + public static final SqlReturnTypeInference DECIMAL_OR_DOUBLE_NULLABLE = + DECIMAL_OR_DOUBLE.andThen(SqlTypeTransforms.TO_NULLABLE); + /** * Type-inference strategy whereby the result type of a call is * {@link #DECIMAL_SCALE0} with a fallback to {@link #ARG0} This rule diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 2e72df1be651..31862d0c14a3 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -374,6 +374,7 @@ public enum BuiltInMethod { EXP(SqlFunctions.class, "exp", double.class), MOD(SqlFunctions.class, "mod", long.class, long.class), POWER(SqlFunctions.class, "power", double.class, double.class), + POWER_PG(SqlFunctions.class, "power", BigDecimal.class, BigDecimal.class), REPEAT(SqlFunctions.class, "repeat", String.class, int.class), SPACE(SqlFunctions.class, "space", int.class), SPLIT(SqlFunctions.class, "split", String.class), diff --git a/site/_docs/reference.md b/site/_docs/reference.md index e9d175a259f0..8a23f9338c2a 100644 --- a/site/_docs/reference.md +++ b/site/_docs/reference.md @@ -2820,6 +2820,7 @@ In the following: | b | PARSE_TIMESTAMP(format, string[, timeZone]) | Uses format specified by *format* to convert *string* representation of timestamp to a TIMESTAMP WITH LOCAL TIME ZONE value in *timeZone* | h s | PARSE_URL(urlString, partToExtract [, keyToExtract] ) | Returns the specified *partToExtract* from the *urlString*. Valid values for *partToExtract* include HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, and USERINFO. *keyToExtract* specifies which query to extract | b s | POW(numeric1, numeric2) | Returns *numeric1* raised to the power *numeric2* +| b c h q m o f s p | POWER(numeric1, numeric2) | Returns *numeric1* raised to the power of *numeric2* | p | RANDOM() | Generates a random double between 0 and 1 inclusive | s | REGEXP(string, regexp) | Equivalent to `string1 RLIKE string2` | b | REGEXP_CONTAINS(string, regexp) | Returns whether *string* is a partial match for the *regexp* diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index d0e529944d2b..1fbccf156914 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -47,6 +47,7 @@ import org.apache.calcite.sql.dialect.AnsiSqlDialect; import org.apache.calcite.sql.fun.LibraryOperator; import org.apache.calcite.sql.fun.SqlLibrary; +import org.apache.calcite.sql.fun.SqlLibraryOperatorTableFactory; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; @@ -63,6 +64,7 @@ import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.sql.util.SqlOperatorTables; import org.apache.calcite.sql.util.SqlString; import org.apache.calcite.sql.validate.SqlConformanceEnum; import org.apache.calcite.sql.validate.SqlNameMatchers; @@ -6483,6 +6485,9 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { final SqlOperatorFixture f = fixture(); f.setFor(SqlStdOperatorTable.POWER, VmName.EXPAND); f.checkScalarApprox("power(2,-2)", "DOUBLE NOT NULL", isExactly("0.25")); + f.checkScalarApprox("power(cast(2 as decimal), cast(-2 as decimal))", + "DOUBLE NOT NULL", + isExactly("0.25")); f.checkNull("power(cast(null as integer),2)"); f.checkNull("power(2,cast(null as double))"); @@ -6492,6 +6497,31 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { false); } + @Test void testPowerDecimalFunc() { + final SqlOperatorFixture f = fixture() + .withOperatorTable( + SqlOperatorTables.chain( + SqlLibraryOperatorTableFactory.INSTANCE.getOperatorTable( + ImmutableList.of(SqlLibrary.POSTGRESQL, SqlLibrary.ALL), + false), + SqlStdOperatorTable.instance())) + .setFor(SqlLibraryOperators.POWER_PG); + f.checkScalarApprox("power(cast(2 as decimal), cast(-2 as decimal))", + "DECIMAL(17, 0) NOT NULL", + isExactly("0.25")); + f.checkScalarApprox("power(cast(2 as decimal), -2)", + "DECIMAL(17, 0) NOT NULL", + isExactly("0.25")); + f.checkScalarApprox("power(2, cast(-2 as decimal))", + "DECIMAL(17, 0) NOT NULL", + isExactly("0.25")); + f.checkScalarApprox("power(2, -2)", "DOUBLE NOT NULL", isExactly("0.25")); + f.checkScalarApprox("power(CAST(0.25 AS DOUBLE), CAST(0.5 AS DOUBLE))", + "DOUBLE NOT NULL", isExactly("0.5")); + f.checkNull("power(null, -2)"); + f.checkNull("power(2, null)"); + } + @Test void testSqrtFunc() { final SqlOperatorFixture f = fixture(); f.setFor(SqlStdOperatorTable.SQRT, VmName.EXPAND);