Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-6325] Add LOG function (enabled in Mysql and Spark library) #3789

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOG2;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOG_MYSQL;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_CONCAT;
Expand Down Expand Up @@ -649,12 +650,14 @@ Builder populate() {
defineMethod(POWER, BuiltInMethod.POWER.method, NullPolicy.STRICT);
defineMethod(POWER_PG, BuiltInMethod.POWER_PG.method, NullPolicy.STRICT);
defineMethod(ABS, BuiltInMethod.ABS.method, NullPolicy.STRICT);
defineMethod(LOG2, BuiltInMethod.LOG2.method, NullPolicy.STRICT);

map.put(LN, new LogImplementor());
map.put(LOG, new LogImplementor());
map.put(LOG10, new LogImplementor());

map.put(LOG_MYSQL, new LogMysqlSparkImplementor());
map.put(LOG2, new LogMysqlSparkImplementor());

defineReflective(RAND, BuiltInMethod.RAND.method,
BuiltInMethod.RAND_SEED.method);
defineReflective(RAND_INTEGER, BuiltInMethod.RAND_INTEGER.method,
Expand Down Expand Up @@ -4210,13 +4213,51 @@ private static List<Expression> args(RexCall call,
switch (call.getOperator().getName()) {
case "LOG":
if (argValueList.size() == 2) {
return list.append(argValueList.get(1));
return list.append(argValueList.get(1)).append(Expressions.constant(0));
}
// fall through
case "LN":
return list.append(Expressions.constant(Math.exp(1))).append(Expressions.constant(0));
case "LOG10":
return list.append(Expressions.constant(BigDecimal.TEN)).append(Expressions.constant(0));
default:
throw new AssertionError("Operator not found: " + call.getOperator());
}
}
}

/** Implementor for the {@code LN}, {@code LOG}, {@code LOG2} and {@code LOG10} operators
* on Mysql and Spark library
*
* <p>Handles all logarithm functions using log rules to determine the
* appropriate base (i.e. base e for LN).
*/
private static class LogMysqlSparkImplementor extends AbstractRexCallImplementor {
LogMysqlSparkImplementor() {
super("log", NullPolicy.STRICT, true);
}

@Override Expression implementSafe(final RexToLixTranslator translator,
final RexCall call, final List<Expression> argValueList) {
return Expressions.call(BuiltInMethod.LOG.method, args(call, argValueList));
}

private static List<Expression> args(RexCall call,
List<Expression> argValueList) {
Expression operand0 = argValueList.get(0);
final Expressions.FluentList<Expression> list = Expressions.list(operand0);
switch (call.getOperator().getName()) {
case "LOG":
if (argValueList.size() == 2) {
return list.append(argValueList.get(1)).append(Expressions.constant(1));
}
// fall through
case "LN":
return list.append(Expressions.constant(Math.exp(1)));
return list.append(Expressions.constant(Math.exp(1))).append(Expressions.constant(1));
case "LOG2":
return list.append(Expressions.constant(2)).append(Expressions.constant(1));
case "LOG10":
return list.append(Expressions.constant(BigDecimal.TEN));
return list.append(Expressions.constant(BigDecimal.TEN)).append(Expressions.constant(1));
default:
throw new AssertionError("Operator not found: " + call.getOperator());
}
Expand Down
39 changes: 20 additions & 19 deletions core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -2789,36 +2789,37 @@ public static double power(BigDecimal b0, BigDecimal b1) {
// LN, LOG, LOG10, LOG2

/** SQL {@code LOG(number, number2)} function applied to double values. */
public static double log(double d0, double d1) {
return Math.log(d0) / Math.log(d1);
public static @Nullable Double log(double number, double number2, int nullFlag) {
if (nullFlag == 1 && number <= 0) {
return null;
}
return Math.log(number) / Math.log(number2);
}

/** SQL {@code LOG(number, number2)} function applied to
* double and BigDecimal values. */
public static double log(double d0, BigDecimal d1) {
return Math.log(d0) / Math.log(d1.doubleValue());
public static @Nullable Double log(double number, BigDecimal number2, int nullFlag) {
if (nullFlag == 1 && number <= 0) {
return null;
}
return Math.log(number) / Math.log(number2.doubleValue());
}

/** SQL {@code LOG(number, number2)} function applied to
* BigDecimal and double values. */
public static double log(BigDecimal d0, double d1) {
return Math.log(d0.doubleValue()) / Math.log(d1);
public static @Nullable Double log(BigDecimal number, double number2, int nullFlag) {
if (nullFlag == 1 && number.doubleValue() <= 0) {
return null;
}
return Math.log(number.doubleValue()) / Math.log(number2);
}

/** SQL {@code LOG(number, number2)} function applied to double values. */
public static double log(BigDecimal d0, BigDecimal d1) {
return Math.log(d0.doubleValue()) / Math.log(d1.doubleValue());
}

/** SQL {@code LOG2(number)} function applied to double values. */
public static @Nullable Double log2(double number) {
return (number <= 0) ? null : log(number, 2);
}

/** SQL {@code LOG2(number)} function applied to
* BigDecimal values. */
public static @Nullable Double log2(BigDecimal number) {
return log2(number.doubleValue());
public static @Nullable Double log(BigDecimal number, BigDecimal number2, int nullFlag) {
if (nullFlag == 1 && number.doubleValue() <= 0) {
return null;
}
return Math.log(number.doubleValue()) / Math.log(number2.doubleValue());
}

// MOD
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/SqlKind.java
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,9 @@ public enum SqlKind {
/** {@code LEAST} function (Oracle). */
LEAST,

/** {@code LOG} function. (Mysql, Spark). */
LOG,

/** {@code DATE_ADD} function (BigQuery Semantics). */
DATE_ADD,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2224,6 +2224,13 @@ private static RelDataType deriveTypeMapFromEntries(SqlOperatorBinding opBinding
OperandTypes.NUMERIC_OPTIONAL_NUMERIC,
SqlFunctionCategory.NUMERIC);

/** The "LOG(numeric, numeric1)" function. Returns the base numeric1 logarithm of numeric. */
@LibraryOperator(libraries = {MYSQL, SPARK})
public static final SqlFunction LOG_MYSQL =
SqlBasicFunction.create(SqlKind.LOG,
ReturnTypes.DOUBLE_FORCE_NULLABLE,
OperandTypes.NUMERIC_OPTIONAL_NUMERIC);

/** The "LOG2(numeric)" function. Returns the base 2 logarithm of numeric. */
@LibraryOperator(libraries = {MYSQL, SPARK})
public static final SqlFunction LOG2 =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,7 @@ public enum BuiltInMethod {
SAFE_DIVIDE(SqlFunctions.class, "safeDivide", double.class, double.class),
SAFE_MULTIPLY(SqlFunctions.class, "safeMultiply", double.class, double.class),
SAFE_SUBTRACT(SqlFunctions.class, "safeSubtract", double.class, double.class),
LOG(SqlFunctions.class, "log", long.class, long.class),
LOG2(SqlFunctions.class, "log2", long.class),
LOG(SqlFunctions.class, "log", long.class, long.class, int.class),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this requires an enum instead of a numeric flag?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tanclary I'm sorry that it took so long to improve it. If you have time, please review this PR.

SEC(SqlFunctions.class, "sec", double.class),
SECH(SqlFunctions.class, "sech", double.class),
SIGN(SqlFunctions.class, "sign", long.class),
Expand Down
1 change: 1 addition & 0 deletions site/_docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -2790,6 +2790,7 @@ In the following:
| b f s | LENGTH(string) | Equivalent to `CHAR_LENGTH(string)`
| h s | LEVENSHTEIN(string1, string2) | Returns the Levenshtein distance between *string1* and *string2*
| b | LOG(numeric1 [, numeric2 ]) | Returns the logarithm of *numeric1* to base *numeric2*, or base e if *numeric2* is not present
| m s | LOG(numeric1 [, numeric2 ]) | Returns the logarithm of *numeric1* to base *numeric2*, or base e if *numeric2* is not present, or null if *numeric1* is 0
| m s | LOG2(numeric) | Returns the base 2 logarithm of *numeric*
| b o s | LPAD(string, length [, pattern ]) | Returns a string or bytes value that consists of *string* prepended to *length* with *pattern*
| b | TO_BASE32(string) | Converts the *string* to base-32 encoded form and returns an encoded string
Expand Down
37 changes: 37 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6978,6 +6978,43 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
f0.forEachLibrary(list(SqlLibrary.MYSQL, SqlLibrary.SPARK), consumer);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6259">[CALCITE-6259]
* Add LOG function (enabled in MYSQL, Spark library)</a>. */
@Test void testLogMysqlSparkFunc() {
final SqlOperatorFixture f0 = fixture();
f0.checkFails("^log(100, 10)^",
"No match found for function signature LOG\\(<NUMERIC>, <NUMERIC>\\)", false);
f0.setFor(SqlLibraryOperators.LOG_MYSQL);
final Consumer<SqlOperatorFixture> consumer = f -> {
f.checkScalarApprox("log(10, 10)", "DOUBLE",
isWithin(1.0, 0.000001));
f.checkScalarApprox("log(64, 8)", "DOUBLE",
isWithin(2.0, 0.000001));
f.checkScalarApprox("log(27,3)", "DOUBLE",
isWithin(3.0, 0.000001));
f.checkScalarApprox("log(100, 10)", "DOUBLE",
isWithin(2.0, 0.000001));
f.checkScalarApprox("log(10, 100)", "DOUBLE",
isWithin(0.5, 0.000001));
f.checkScalarApprox("log(cast(10e6 as double), 10)", "DOUBLE",
isWithin(7.0, 0.000001));
f.checkScalarApprox("log(cast(10e8 as float), 10)", "DOUBLE",
isWithin(9.0, 0.000001));
f.checkScalarApprox("log(cast(10e-3 as real), 10)", "DOUBLE",
isWithin(-2.0, 0.000001));
f.checkNull("log(cast(null as real), 10)");
f.checkNull("log(10, cast(null as real))");
f.checkNull("log(0, 2)");
f.checkNull("log(0,-2)");
f.checkNull("log(0, +0.0)");
f.checkNull("log(0, 0.0)");
f.checkNull("log(null)");
f.checkNull("log(cast(null as real))");
};
f0.forEachLibrary(list(SqlLibrary.MYSQL, SqlLibrary.SPARK), consumer);
}

@Test void testRandFunc() {
final SqlOperatorFixture f = fixture();
f.setFor(SqlStdOperatorTable.RAND, VmName.EXPAND);
Expand Down
Loading