diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index 2ff47db21940..0ad0521c7d66 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -211,6 +211,7 @@ import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOG2; import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND; import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR; +import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOG_MS; import static org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD; import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP; import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_CONCAT; @@ -645,12 +646,14 @@ Builder populate() { defineMethod(EXP, BuiltInMethod.EXP.method, NullPolicy.STRICT); defineMethod(POWER, BuiltInMethod.POWER.method, NullPolicy.STRICT); defineMethod(ABS, BuiltInMethod.ABS.method, NullPolicy.STRICT); - defineMethod(LOG2, BuiltInMethod.LOG2.method, NullPolicy.STRICT); map.put(LN, new LogImplementor()); map.put(LOG, new LogImplementor()); map.put(LOG10, new LogImplementor()); + map.put(LOG_MS, new LogMSImplementor()); + map.put(LOG2, new LogMSImplementor()); + defineReflective(RAND, BuiltInMethod.RAND.method, BuiltInMethod.RAND_SEED.method); defineReflective(RAND_INTEGER, BuiltInMethod.RAND_INTEGER.method, @@ -4166,13 +4169,51 @@ private static List args(RexCall call, switch (call.getOperator().getName()) { case "LOG": if (argValueList.size() == 2) { - return list.append(argValueList.get(1)); + return list.append(argValueList.get(1)).append(Expressions.constant(0)); + } + // fall through + case "LN": + return list.append(Expressions.constant(Math.exp(1))).append(Expressions.constant(0)); + case "LOG10": + return list.append(Expressions.constant(BigDecimal.TEN)).append(Expressions.constant(0)); + default: + throw new AssertionError("Operator not found: " + call.getOperator()); + } + } + } + + /** Implementor for the {@code LN}, {@code LOG}, {@code LOG2} and {@code LOG10} operators + * on Mysql and Spark library + * + *

Handles all logarithm functions using log rules to determine the + * appropriate base (i.e. base e for LN). + */ + private static class LogMSImplementor extends AbstractRexCallImplementor { + LogMSImplementor() { + super("logMS", NullPolicy.STRICT, true); + } + + @Override Expression implementSafe(final RexToLixTranslator translator, + final RexCall call, final List argValueList) { + return Expressions.call(BuiltInMethod.LOG.method, args(call, argValueList)); + } + + private static List args(RexCall call, + List argValueList) { + Expression operand0 = argValueList.get(0); + final Expressions.FluentList list = Expressions.list(operand0); + switch (call.getOperator().getName()) { + case "LOG": + if (argValueList.size() == 2) { + return list.append(argValueList.get(1)).append(Expressions.constant(1)); } // fall through case "LN": - return list.append(Expressions.constant(Math.exp(1))); + return list.append(Expressions.constant(Math.exp(1))).append(Expressions.constant(1)); + case "LOG2": + return list.append(Expressions.constant(2)).append(Expressions.constant(1)); case "LOG10": - return list.append(Expressions.constant(BigDecimal.TEN)); + return list.append(Expressions.constant(BigDecimal.TEN)).append(Expressions.constant(1)); default: throw new AssertionError("Operator not found: " + call.getOperator()); } diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java index 8fbdfb11a487..6a390bb2bdb9 100644 --- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java +++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java @@ -2788,36 +2788,37 @@ public static double power(BigDecimal b0, BigDecimal b1) { // LN, LOG, LOG10, LOG2 /** SQL {@code LOG(number, number2)} function applied to double values. */ - public static double log(double d0, double d1) { - return Math.log(d0) / Math.log(d1); + public static @Nullable Double log(double number, double number2, int nullFlag) { + if (nullFlag == 1 && number <= 0) { + return null; + } + return Math.log(number) / Math.log(number2); } /** SQL {@code LOG(number, number2)} function applied to * double and BigDecimal values. */ - public static double log(double d0, BigDecimal d1) { - return Math.log(d0) / Math.log(d1.doubleValue()); + public static @Nullable Double log(double number, BigDecimal number2, int nullFlag) { + if (nullFlag == 1 && number <= 0) { + return null; + } + return Math.log(number) / Math.log(number2.doubleValue()); } /** SQL {@code LOG(number, number2)} function applied to * BigDecimal and double values. */ - public static double log(BigDecimal d0, double d1) { - return Math.log(d0.doubleValue()) / Math.log(d1); + public static @Nullable Double log(BigDecimal number, double number2, int nullFlag) { + if (nullFlag == 1 && number.doubleValue() <= 0) { + return null; + } + return Math.log(number.doubleValue()) / Math.log(number2); } /** SQL {@code LOG(number, number2)} function applied to double values. */ - public static double log(BigDecimal d0, BigDecimal d1) { - return Math.log(d0.doubleValue()) / Math.log(d1.doubleValue()); - } - - /** SQL {@code LOG2(number)} function applied to double values. */ - public static @Nullable Double log2(double number) { - return (number <= 0) ? null : log(number, 2); - } - - /** SQL {@code LOG2(number)} function applied to - * BigDecimal values. */ - public static @Nullable Double log2(BigDecimal number) { - return log2(number.doubleValue()); + public static @Nullable Double log(BigDecimal number, BigDecimal number2, int nullFlag) { + if (nullFlag == 1 && number.doubleValue() <= 0) { + return null; + } + return Math.log(number.doubleValue()) / Math.log(number2.doubleValue()); } // MOD diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index 6996ecb3160a..06462b3c21ca 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -2199,13 +2199,21 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding * @see SqlStdOperatorTable#LN * @see SqlStdOperatorTable#LOG10 */ - @LibraryOperator(libraries = {BIG_QUERY}) + @LibraryOperator(libraries = {BIG_QUERY, POSTGRESQL}) public static final SqlFunction LOG = SqlBasicFunction.create("LOG", ReturnTypes.DOUBLE_NULLABLE, OperandTypes.NUMERIC_OPTIONAL_NUMERIC, SqlFunctionCategory.NUMERIC); + /** The "LOG(numeric, numeric1)" function. Returns the base numeric1 logarithm of numeric. */ + @LibraryOperator(libraries = {MYSQL, SPARK}) + public static final SqlFunction LOG_MS = + SqlBasicFunction.create("LOG", + ReturnTypes.DOUBLE_FORCE_NULLABLE, + OperandTypes.NUMERIC_OPTIONAL_NUMERIC, + SqlFunctionCategory.NUMERIC); + /** The "LOG2(numeric)" function. Returns the base 2 logarithm of numeric. */ @LibraryOperator(libraries = {MYSQL, SPARK}) public static final SqlFunction LOG2 = @@ -2214,6 +2222,7 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding OperandTypes.NUMERIC, SqlFunctionCategory.NUMERIC); + @LibraryOperator(libraries = {BIG_QUERY, SPARK}) public static final SqlFunction POW = SqlStdOperatorTable.POWER.withName("POW"); diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 2e72df1be651..f70f1ac8c49f 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -512,8 +512,7 @@ public enum BuiltInMethod { SAFE_DIVIDE(SqlFunctions.class, "safeDivide", double.class, double.class), SAFE_MULTIPLY(SqlFunctions.class, "safeMultiply", double.class, double.class), SAFE_SUBTRACT(SqlFunctions.class, "safeSubtract", double.class, double.class), - LOG(SqlFunctions.class, "log", long.class, long.class), - LOG2(SqlFunctions.class, "log2", long.class), + LOG(SqlFunctions.class, "log", long.class, long.class, int.class), SEC(SqlFunctions.class, "sec", double.class), SECH(SqlFunctions.class, "sech", double.class), SIGN(SqlFunctions.class, "sign", long.class), diff --git a/site/_docs/reference.md b/site/_docs/reference.md index e9d175a259f0..85af1f5ffc29 100644 --- a/site/_docs/reference.md +++ b/site/_docs/reference.md @@ -2789,7 +2789,8 @@ In the following: | f s | LEN(string) | Equivalent to `CHAR_LENGTH(string)` | b f s | LENGTH(string) | Equivalent to `CHAR_LENGTH(string)` | h s | LEVENSHTEIN(string1, string2) | Returns the Levenshtein distance between *string1* and *string2* -| b | LOG(numeric1 [, numeric2 ]) | Returns the logarithm of *numeric1* to base *numeric2*, or base e if *numeric2* is not present +| b p | LOG(numeric1 [, numeric2 ]) | Returns the logarithm of *numeric1* to base *numeric2*, or base e if *numeric2* is not present +| m s | LOG(numeric1 [, numeric2 ]) | Returns the logarithm of *numeric1* to base *numeric2*, or base e if *numeric2* is not present | m s | LOG2(numeric) | Returns the base 2 logarithm of *numeric* | b o s | LPAD(string, length [, pattern ]) | Returns a string or bytes value that consists of *string* prepended to *length* with *pattern* | b | TO_BASE32(string) | Converts the *string* to base-32 encoded form and returns an encoded string diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index d0e529944d2b..e6964127122b 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -6612,25 +6612,27 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { .setFor(SqlLibraryOperators.LOG, VmName.EXPAND); f0.checkFails("^log(100, 10)^", "No match found for function signature LOG\\(, \\)", false); - final SqlOperatorFixture f = f0.withLibrary(SqlLibrary.BIG_QUERY); - f.checkScalarApprox("log(10, 10)", "DOUBLE NOT NULL", - isWithin(1.0, 0.000001)); - f.checkScalarApprox("log(64, 8)", "DOUBLE NOT NULL", - isWithin(2.0, 0.000001)); - f.checkScalarApprox("log(27,3)", "DOUBLE NOT NULL", - isWithin(3.0, 0.000001)); - f.checkScalarApprox("log(100, 10)", "DOUBLE NOT NULL", - isWithin(2.0, 0.000001)); - f.checkScalarApprox("log(10, 100)", "DOUBLE NOT NULL", - isWithin(0.5, 0.000001)); - f.checkScalarApprox("log(cast(10e6 as double), 10)", "DOUBLE NOT NULL", - isWithin(7.0, 0.000001)); - f.checkScalarApprox("log(cast(10e8 as float), 10)", "DOUBLE NOT NULL", - isWithin(9.0, 0.000001)); - f.checkScalarApprox("log(cast(10e-3 as real), 10)", "DOUBLE NOT NULL", - isWithin(-2.0, 0.000001)); - f.checkNull("log(cast(null as real), 10)"); - f.checkNull("log(10, cast(null as real))"); + final Consumer consumer = f -> { + f.checkScalarApprox("log(10, 10)", "DOUBLE NOT NULL", + isWithin(1.0, 0.000001)); + f.checkScalarApprox("log(64, 8)", "DOUBLE NOT NULL", + isWithin(2.0, 0.000001)); + f.checkScalarApprox("log(27,3)", "DOUBLE NOT NULL", + isWithin(3.0, 0.000001)); + f.checkScalarApprox("log(100, 10)", "DOUBLE NOT NULL", + isWithin(2.0, 0.000001)); + f.checkScalarApprox("log(10, 100)", "DOUBLE NOT NULL", + isWithin(0.5, 0.000001)); + f.checkScalarApprox("log(cast(10e6 as double), 10)", "DOUBLE NOT NULL", + isWithin(7.0, 0.000001)); + f.checkScalarApprox("log(cast(10e8 as float), 10)", "DOUBLE NOT NULL", + isWithin(9.0, 0.000001)); + f.checkScalarApprox("log(cast(10e-3 as real), 10)", "DOUBLE NOT NULL", + isWithin(-2.0, 0.000001)); + f.checkNull("log(cast(null as real), 10)"); + f.checkNull("log(10, cast(null as real))"); + }; + f0.forEachLibrary(list(SqlLibrary.BIG_QUERY, SqlLibrary.POSTGRESQL), consumer); } /** Test case for @@ -6670,6 +6672,43 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { f0.forEachLibrary(list(SqlLibrary.MYSQL, SqlLibrary.SPARK), consumer); } + /** Test case for + * [CALCITE-6259] + * Add LOG function (enabled in MYSQL, Spark library). */ + @Test void testLogMSFunc() { + final SqlOperatorFixture f0 = Fixtures.forOperators(true); + f0.checkFails("^log(100, 10)^", + "No match found for function signature LOG\\(, \\)", false); + f0.setFor(SqlLibraryOperators.LOG_MS); + final Consumer consumer = f -> { + f.checkScalarApprox("log(10, 10)", "DOUBLE", + isWithin(1.0, 0.000001)); + f.checkScalarApprox("log(64, 8)", "DOUBLE", + isWithin(2.0, 0.000001)); + f.checkScalarApprox("log(27,3)", "DOUBLE", + isWithin(3.0, 0.000001)); + f.checkScalarApprox("log(100, 10)", "DOUBLE", + isWithin(2.0, 0.000001)); + f.checkScalarApprox("log(10, 100)", "DOUBLE", + isWithin(0.5, 0.000001)); + f.checkScalarApprox("log(cast(10e6 as double), 10)", "DOUBLE", + isWithin(7.0, 0.000001)); + f.checkScalarApprox("log(cast(10e8 as float), 10)", "DOUBLE", + isWithin(9.0, 0.000001)); + f.checkScalarApprox("log(cast(10e-3 as real), 10)", "DOUBLE", + isWithin(-2.0, 0.000001)); + f.checkNull("log(cast(null as real), 10)"); + f.checkNull("log(10, cast(null as real))"); + f.checkNull("log(0, 2)"); + f.checkNull("log(0,-2)"); + f.checkNull("log(0, +0.0)"); + f.checkNull("log(0, 0.0)"); + f.checkNull("log(null)"); + f.checkNull("log(cast(null as real))"); + }; + f0.forEachLibrary(list(SqlLibrary.MYSQL, SqlLibrary.SPARK), consumer); + } + @Test void testRandFunc() { final SqlOperatorFixture f = fixture(); f.setFor(SqlStdOperatorTable.RAND, VmName.EXPAND);