Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-6312] Add LOG function (enabled in PostgreSQL library) #3839

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import org.apache.calcite.sql.fun.SqlItemOperator;
import org.apache.calcite.sql.fun.SqlJsonArrayAggAggFunction;
import org.apache.calcite.sql.fun.SqlJsonObjectAggAggFunction;
import org.apache.calcite.sql.fun.SqlLibrary;
import org.apache.calcite.sql.fun.SqlQuantifyOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlTrimFunction;
Expand All @@ -75,6 +76,7 @@
import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction;
import org.apache.calcite.sql.validate.SqlUserDefinedTableMacro;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -216,6 +218,7 @@
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOG_MYSQL;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOG_POSTGRES;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_CONCAT;
Expand Down Expand Up @@ -654,12 +657,13 @@ Builder populate() {
defineMethod(POWER_PG, BuiltInMethod.POWER_PG.method, NullPolicy.STRICT);
defineMethod(ABS, BuiltInMethod.ABS.method, NullPolicy.STRICT);

map.put(LN, new LogImplementor());
map.put(LOG, new LogImplementor());
map.put(LOG10, new LogImplementor());
map.put(LN, new LogImplementor(SqlLibrary.BIG_QUERY));
map.put(LOG, new LogImplementor(SqlLibrary.BIG_QUERY));
map.put(LOG10, new LogImplementor(SqlLibrary.BIG_QUERY));

map.put(LOG_MYSQL, new LogMysqlImplementor());
map.put(LOG2, new LogMysqlImplementor());
map.put(LOG_POSTGRES, new LogImplementor(SqlLibrary.POSTGRESQL));
map.put(LOG_MYSQL, new LogImplementor(SqlLibrary.MYSQL));
map.put(LOG2, new LogImplementor(SqlLibrary.MYSQL));

defineReflective(RAND, BuiltInMethod.RAND.method,
BuiltInMethod.RAND_SEED.method);
Expand Down Expand Up @@ -4202,67 +4206,57 @@ private static class LogicalNotImplementor extends AbstractRexCallImplementor {
* appropriate base (i.e. base e for LN).
*/
private static class LogImplementor extends AbstractRexCallImplementor {
LogImplementor() {
private final SqlLibrary library;
LogImplementor(SqlLibrary library) {
super("log", NullPolicy.STRICT, true);
this.library = library;
}

@Override Expression implementSafe(final RexToLixTranslator translator,
final RexCall call, final List<Expression> argValueList) {
return Expressions.call(BuiltInMethod.LOG.method, args(call, argValueList));
}

return Expressions.
call(BuiltInMethod.LOG.method, args(call, argValueList, library));
}

/**
* This method is used to handle the implementation of different log functions.
* It generates the corresponding expression list based on the input function name
* and argument list.
*
* @param call The RexCall that contains the function call information.
* @param argValueList The list of argument expressions.
* @param library The SQL library that the function belongs to.
* @return A list of expressions that represents the implementation of the log function.
*/
private static List<Expression> args(RexCall call,
List<Expression> argValueList) {
Expression operand0 = argValueList.get(0);
final Expressions.FluentList<Expression> list = Expressions.list(operand0);
switch (call.getOperator().getName()) {
case "LOG":
if (argValueList.size() == 2) {
return list.append(argValueList.get(1)).append(Expressions.constant(0));
}
// fall through
case "LN":
return list.append(Expressions.constant(Math.exp(1))).append(Expressions.constant(0));
case "LOG10":
return list.append(Expressions.constant(BigDecimal.TEN)).append(Expressions.constant(0));
default:
throw new AssertionError("Operator not found: " + call.getOperator());
List<Expression> argValueList, SqlLibrary library) {
Pair<Expression, Expression> operands;
Expression operand0;
Expression operand1;
if (argValueList.size() == 1) {
operands = library == SqlLibrary.POSTGRESQL
? Pair.of(argValueList.get(0), Expressions.constant(BigDecimal.TEN))
: Pair.of(argValueList.get(0), Expressions.constant(Math.exp(1)));
} else {
operands = library == SqlLibrary.BIG_QUERY
? Pair.of(argValueList.get(0), argValueList.get(1))
: Pair.of(argValueList.get(1), argValueList.get(0));
}
}
}

/** Implementor for the {@code LN}, {@code LOG}, {@code LOG2} and {@code LOG10} operators
* on Mysql and Spark library
*
* <p>Handles all logarithm functions using log rules to determine the
* appropriate base (i.e. base e for LN).
*/
private static class LogMysqlImplementor extends AbstractRexCallImplementor {
LogMysqlImplementor() {
super("log", NullPolicy.STRICT, true);
}

@Override Expression implementSafe(final RexToLixTranslator translator,
final RexCall call, final List<Expression> argValueList) {
return Expressions.call(BuiltInMethod.LOG.method, args(call, argValueList));
}

private static List<Expression> args(RexCall call,
List<Expression> argValueList) {
Expression operand0 = argValueList.get(0);
operand0 = operands.left;
operand1 = operands.right;
boolean nonPositiveIsNull = library == SqlLibrary.MYSQL ? true : false;
final Expressions.FluentList<Expression> list = Expressions.list(operand0);
switch (call.getOperator().getName()) {
case "LOG":
if (argValueList.size() == 2) {
return list.append(argValueList.get(1)).append(Expressions.constant(1));
}
// fall through
return list.append(operand1).append(Expressions.constant(nonPositiveIsNull));
case "LN":
return list.append(Expressions.constant(Math.exp(1))).append(Expressions.constant(1));
return list.append(Expressions.constant(Math.exp(1)))
.append(Expressions.constant(nonPositiveIsNull));
case "LOG2":
return list.append(Expressions.constant(2)).append(Expressions.constant(1));
return list.append(Expressions.constant(2)).append(Expressions.constant(nonPositiveIsNull));
case "LOG10":
return list.append(Expressions.constant(BigDecimal.TEN)).append(Expressions.constant(1));
return list.append(Expressions.constant(BigDecimal.TEN))
.append(Expressions.constant(nonPositiveIsNull));
default:
throw new AssertionError("Operator not found: " + call.getOperator());
}
Expand Down
63 changes: 44 additions & 19 deletions core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -2785,41 +2785,66 @@ public static double power(BigDecimal b0, BigDecimal b1) {
return Math.pow(b0.doubleValue(), b1.doubleValue());
}


// LN, LOG, LOG10, LOG2

/** SQL {@code LOG(number, number2)} function applied to double values. */
public static @Nullable Double log(double number, double number2, int nullFlag) {
if (nullFlag == 1 && number <= 0) {
/**
* SQL {@code LOG(number, base)} function applied to double values.
*
* @param nonPositiveIsNull if true return null for non-positive values
*/
public static @Nullable Double log(double number, double base, boolean nonPositiveIsNull) {
if (nonPositiveIsNull && number <= 0) {
return null;
}
return Math.log(number) / Math.log(number2);
if (number <= 0 || base <= 0) {
throw new IllegalArgumentException("Cannot take logarithm of zero or negative number");
}
return Math.log(number) / Math.log(base);
}

/** SQL {@code LOG(number, number2)} function applied to
* double and BigDecimal values. */
public static @Nullable Double log(double number, BigDecimal number2, int nullFlag) {
if (nullFlag == 1 && number <= 0) {
/** SQL {@code LOG(number, base)} function applied to
* double and BigDecimal values.
*
* @param nonPositiveIsNull if true return null for non-positive values
*/
public static @Nullable Double log(double number, BigDecimal base, boolean nonPositiveIsNull) {
if (nonPositiveIsNull && number <= 0) {
return null;
}
return Math.log(number) / Math.log(number2.doubleValue());
if (number <= 0 || base.doubleValue() <= 0) {
throw new IllegalArgumentException("Cannot take logarithm of zero or negative number");
}
return Math.log(number) / Math.log(base.doubleValue());
}

/** SQL {@code LOG(number, number2)} function applied to
* BigDecimal and double values. */
public static @Nullable Double log(BigDecimal number, double number2, int nullFlag) {
if (nullFlag == 1 && number.doubleValue() <= 0) {
/** SQL {@code LOG(number, base)} function applied to
* BigDecimal and double values.
*
* @param nonPositiveIsNull if true return null for non-positive values
*/
public static @Nullable Double log(BigDecimal number, double base, Boolean nonPositiveIsNull) {
if (nonPositiveIsNull && number.doubleValue() <= 0) {
return null;
}
return Math.log(number.doubleValue()) / Math.log(number2);
if (number.doubleValue() <= 0 || base <= 0) {
throw new IllegalArgumentException("Cannot take logarithm of zero or negative number");
}
return Math.log(number.doubleValue()) / Math.log(base);
}

/** SQL {@code LOG(number, number2)} function applied to double values. */
public static @Nullable Double log(BigDecimal number, BigDecimal number2, int nullFlag) {
if (nullFlag == 1 && number.doubleValue() <= 0) {
/** SQL {@code LOG(number, base)} function applied to double values.
*
* @param nonPositiveIsNull if true return null for non-positive values
*/
public static @Nullable Double log(BigDecimal number, BigDecimal base,
Boolean nonPositiveIsNull) {
if (nonPositiveIsNull && number.doubleValue() <= 0) {
return null;
}
return Math.log(number.doubleValue()) / Math.log(number2.doubleValue());
if (number.doubleValue() <= 0 || base.doubleValue() <= 0) {
throw new IllegalArgumentException("Cannot take logarithm of zero or negative number");
}
return Math.log(number.doubleValue()) / Math.log(base.doubleValue());
}

// MOD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2237,13 +2237,23 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding
OperandTypes.NUMERIC_OPTIONAL_NUMERIC,
SqlFunctionCategory.NUMERIC);

/** The "LOG(numeric, numeric1)" function. Returns the base numeric1 logarithm of numeric. */
/** The "LOG(numeric1 [, numeric2 ]) " function. Returns the logarithm of numeric2
* to base numeric1.*/
@LibraryOperator(libraries = {MYSQL, SPARK})
public static final SqlFunction LOG_MYSQL =
SqlBasicFunction.create(SqlKind.LOG,
ReturnTypes.DOUBLE_FORCE_NULLABLE,
OperandTypes.NUMERIC_OPTIONAL_NUMERIC);

/** The "LOG(numeric1 [, numeric2 ]) " function. Returns the logarithm of numeric2
* to base numeric1.*/
@LibraryOperator(libraries = {POSTGRESQL}, exceptLibraries = {REDSHIFT})
public static final SqlFunction LOG_POSTGRES =
new SqlBasicFunction("LOG", SqlKind.LOG,
SqlSyntax.FUNCTION, true, ReturnTypes.DOUBLE_NULLABLE, null,
OperandHandlers.DEFAULT, OperandTypes.NUMERIC_OPTIONAL_NUMERIC, 0,
SqlFunctionCategory.NUMERIC, call -> SqlMonotonicity.NOT_MONOTONIC, false) { };

/** The "LOG2(numeric)" function. Returns the base 2 logarithm of numeric. */
@LibraryOperator(libraries = {MYSQL, SPARK})
public static final SqlFunction LOG2 =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ public enum BuiltInMethod {
SAFE_DIVIDE(SqlFunctions.class, "safeDivide", double.class, double.class),
SAFE_MULTIPLY(SqlFunctions.class, "safeMultiply", double.class, double.class),
SAFE_SUBTRACT(SqlFunctions.class, "safeSubtract", double.class, double.class),
LOG(SqlFunctions.class, "log", long.class, long.class, int.class),
LOG(SqlFunctions.class, "log", long.class, long.class, boolean.class),
SEC(SqlFunctions.class, "sec", double.class),
SECH(SqlFunctions.class, "sech", double.class),
SIGN(SqlFunctions.class, "sign", long.class),
Expand Down
5 changes: 3 additions & 2 deletions site/_docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -2805,8 +2805,9 @@ In the following:
| f s | LEN(string) | Equivalent to `CHAR_LENGTH(string)`
| b f s | LENGTH(string) | Equivalent to `CHAR_LENGTH(string)`
| h s | LEVENSHTEIN(string1, string2) | Returns the Levenshtein distance between *string1* and *string2*
| b | LOG(numeric1 [, numeric2 ]) | Returns the logarithm of *numeric1* to base *numeric2*, or base e if *numeric2* is not present
| m s | LOG(numeric1 [, numeric2 ]) | Returns the logarithm of *numeric1* to base *numeric2*, or base e if *numeric2* is not present, or null if *numeric1* is 0 or negative
| b | LOG(numeric1 [, base ]) | Returns the logarithm of *numeric1* to base *base*, or base e if *base* is not present, or error if *numeric1* is 0 or negative
| m s | LOG([, base ], numeric1) | Returns the logarithm of *numeric1* to base *base*, or base e if *base* is not present, or null if *numeric1* is 0 or negative
| p | LOG([, base ], numeric1 ) | Returns the logarithm of *numeric1* to base *base*, or base 10 if *numeric1* is not present, or error if *numeric1* is 0 or negative
| m s | LOG2(numeric) | Returns the base 2 logarithm of *numeric*
| b o p r s | LPAD(string, length [, pattern ]) | Returns a string or bytes value that consists of *string* prepended to *length* with *pattern*
| b | TO_BASE32(string) | Converts the *string* to base-32 encoded form and returns an encoded string
Expand Down
62 changes: 52 additions & 10 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6959,14 +6959,22 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
isWithin(2.0, 0.000001));
f.checkScalarApprox("log(10, 100)", "DOUBLE NOT NULL",
isWithin(0.5, 0.000001));
f.checkScalarApprox("log(cast(10e6 as double), 10)", "DOUBLE NOT NULL",
f.checkScalarApprox("log(cast(1e7 as double), 10)", "DOUBLE NOT NULL",
isWithin(7.0, 0.000001));
f.checkScalarApprox("log(cast(10e8 as float), 10)", "DOUBLE NOT NULL",
isWithin(9.0, 0.000001));
f.checkScalarApprox("log(cast(10e-3 as real), 10)", "DOUBLE NOT NULL",
isWithin(-2.0, 0.000001));
f.checkScalarApprox("log(10)", "DOUBLE NOT NULL",
isWithin(2.302585092994046, 0.000001));
f.checkNull("log(cast(null as real), 10)");
f.checkNull("log(10, cast(null as real))");
f.checkFails("log(0)",
"Cannot take logarithm of zero or negative number", true);
f.checkFails("log(0, 64)",
caicancai marked this conversation as resolved.
Show resolved Hide resolved
"Cannot take logarithm of zero or negative number", true);
f.checkFails("log(64, 0)",
"Cannot take logarithm of zero or negative number", true);
}

/** Test case for
Expand Down Expand Up @@ -7017,24 +7025,27 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
final Consumer<SqlOperatorFixture> consumer = f -> {
f.checkScalarApprox("log(10, 10)", "DOUBLE",
isWithin(1.0, 0.000001));
f.checkScalarApprox("log(64, 8)", "DOUBLE",
f.checkScalarApprox("log(8, 64)", "DOUBLE",
isWithin(2.0, 0.000001));
f.checkScalarApprox("log(27,3)", "DOUBLE",
f.checkScalarApprox("log(3,27)", "DOUBLE",
isWithin(3.0, 0.000001));
f.checkScalarApprox("log(100, 10)", "DOUBLE",
isWithin(2.0, 0.000001));
f.checkScalarApprox("log(10, 100)", "DOUBLE",
isWithin(2.0, 0.000001));
f.checkScalarApprox("log(100, 10)", "DOUBLE",
isWithin(0.5, 0.000001));
f.checkScalarApprox("log(cast(10e6 as double), 10)", "DOUBLE",
f.checkScalarApprox("log(10, cast(1e7 as double))", "DOUBLE",
isWithin(7.0, 0.000001));
f.checkScalarApprox("log(cast(10e8 as float), 10)", "DOUBLE",
f.checkScalarApprox("log(10, cast(1e9 as float))", "DOUBLE",
isWithin(9.0, 0.000001));
f.checkScalarApprox("log(cast(10e-3 as real), 10)", "DOUBLE",
// real type is equivalent to double type
f.checkScalarApprox("log(10, cast(1e-2 as real))", "DOUBLE",
isWithin(-2.0, 0.000001));
f.checkScalarApprox("log(10)", "DOUBLE",
isWithin(2.302585092994046, 0.000001));
f.checkNull("log(cast(null as real), 10)");
f.checkNull("log(10, cast(null as real))");
f.checkNull("log(0, 2)");
f.checkNull("log(0,-2)");
f.checkNull("log(2, 0)");
f.checkNull("log(-2,0)");
f.checkNull("log(0, +0.0)");
f.checkNull("log(0, 0.0)");
f.checkNull("log(null)");
Expand All @@ -7045,6 +7056,37 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
f0.forEachLibrary(list(SqlLibrary.MYSQL, SqlLibrary.SPARK), consumer);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6312">[CALCITE-6312]
* Add LOG function (enabled in PostgreSQL library)</a>. */
@Test void testPostgresLogFunc() {
final SqlOperatorFixture f0 = fixture()
.setFor(SqlLibraryOperators.LOG_POSTGRES, VmName.EXPAND);
f0.checkFails("^log(100, 10)^",
"No match found for function signature LOG\\(<NUMERIC>, <NUMERIC>\\)", false);
final SqlOperatorFixture f = f0.withLibrary(SqlLibrary.POSTGRESQL);
f.checkScalar("log(10, 10)", 1.0,
"DOUBLE NOT NULL");
f.checkScalar("log(8, 64)", 2.0,
"DOUBLE NOT NULL");
f.checkScalar("log(10, 100)", 2.0,
"DOUBLE NOT NULL");
f.checkScalar("log(100, 10)", 0.5,
"DOUBLE NOT NULL");
f.checkScalar("log(10, cast(1e7 as double))", 7.0,
"DOUBLE NOT NULL");
f.checkScalar("log(10)", 1.0,
"DOUBLE NOT NULL");
f.checkNull("log(cast(null as real), 10)");
f.checkNull("log(10, cast(null as real))");
f.checkFails("log(0)",
"Cannot take logarithm of zero or negative number", true);
f.checkFails("log(0, 64)",
"Cannot take logarithm of zero or negative number", true);
f.checkFails("log(64, 0)",
"Cannot take logarithm of zero or negative number", true);
}

@Test void testRandFunc() {
final SqlOperatorFixture f = fixture();
f.setFor(SqlStdOperatorTable.RAND, VmName.EXPAND);
Expand Down
Loading