diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index ef755aae3fb07..858ab923490fc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -29,33 +29,62 @@ *

* The JVM type of result values produced by this function must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. + * The mapping between {@link DataType} and the corresponding JVM type is defined below. *

* IMPORTANT: the default implementation of {@link #produceResult} throws - * {@link UnsupportedOperationException}. Users can choose to override this method, or implement - * a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters - * instead of a {@link InternalRow}. The magic method will be loaded by Spark through Java - * reflection and will also provide better performance in general, due to optimizations such as - * codegen, removal of Java boxing, etc. - * + * {@link UnsupportedOperationException}. Users must either override this method or + * implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters + * instead of a {@link InternalRow}. The magic method approach is generally recommended because it + * provides better performance than the default {@link #produceResult}, due to optimizations such + * as whole-stage codegen, elimination of Java boxing, etc. + *

+ * In addition, for stateless Java functions, users can optionally define the + * {@link #MAGIC_METHOD_NAME} as a static method, which further avoids certain runtime costs such + * as Java dynamic dispatch. + *

* For example, a scalar UDF for adding two integers can be defined as follow with the magic * method approach: * *

  *   public class IntegerAdd implements {@code ScalarFunction<Integer>} {
+ *     public DataType[] inputTypes() {
+ *       return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
+ *     }
  *     public int invoke(int left, int right) {
  *       return left + right;
  *     }
  *   }
  * 
- * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over - * {@link #produceResult} to evalaute the inputs. In general Spark looks up the magic method by - * first converting the actual input SQL data types to their corresponding Java types following - * the mapping defined below, and then checking if there is a matching method from all the - * declared methods in the UDF class, using method name (i.e., {@link #MAGIC_METHOD_NAME}) and - * the Java types. If no magic method is found, Spark will falls back to use {@link #produceResult}. + * In the above, since {@link #MAGIC_METHOD_NAME} is defined with matching parameter types and + * return type, Spark will use it to evaluate the inputs. + *

+ * As another example, in the following: + *

+ *   public class IntegerAdd implements {@code ScalarFunction<Integer>} {
+ *     public DataType[] inputTypes() {
+ *       return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
+ *     }
+ *     public static int invoke(int left, int right) {
+ *       return left + right;
+ *     }
+ *     public Integer produceResult(InternalRow input) {
+ *       return input.getInt(0) + input.getInt(1);
+ *     }
+ *   }
+ * 
+ * + * the class defines both the magic method and {@link #produceResult}, and Spark will use + * {@link #MAGIC_METHOD_NAME} over {@link #produceResult(InternalRow)} since it takes higher + * precedence. Also note that the magic method is declared as a static method in this case. + *

+ * Resolution of the magic method is done during query analysis, where Spark looks up the magic + * method by first converting the actual input SQL data types to their corresponding Java types + * following the mapping defined below, and then checking if there is a matching method among all + * the declared methods in the UDF class, using the method name and the Java types. *
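To make that resolution step concrete, below is a minimal sketch of such a reflection-based lookup. It is not Spark's actual `findMethod` implementation; the helper names `javaClassFor` and `lookupMagicMethod` are illustrative, and only a small subset of the type mapping is shown. The flag indicating whether the matched method is static corresponds to the Analyzer change later in this diff, which compiles static magic methods to `StaticInvoke` and instance magic methods to `Invoke`.

```scala
// Hypothetical sketch only; Spark's real lookup lives in the analyzer's findMethod.
import java.lang.reflect.Modifier

import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object MagicMethodLookup {
  // Small subset of the SQL-data-type-to-JVM-type mapping, for illustration.
  private def javaClassFor(dt: DataType): Class[_] = dt match {
    case IntegerType => classOf[Int]        // int
    case LongType    => classOf[Long]       // long
    case StringType  => classOf[UTF8String] // UTF8String
    case other => throw new IllegalArgumentException(s"unsupported type: $other")
  }

  // Returns the matching "invoke" method (if any) and whether it is static.
  def lookupMagicMethod(
      udfClass: Class[_],
      inputTypes: Seq[DataType]): Option[(java.lang.reflect.Method, Boolean)] = {
    val argClasses = inputTypes.map(javaClassFor)
    udfClass.getDeclaredMethods
      .filter(_.getName == "invoke")
      .find(m => m.getParameterTypes.toSeq == argClasses)
      .map(m => (m, Modifier.isStatic(m.getModifiers)))
  }
}
```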

- * The following are the mapping from {@link DataType SQL data type} to Java type through - * the magic method approach: + * The following is the mapping from {@link DataType SQL data type} to Java type, which is used + * by Spark to infer parameter types for the magic methods as well as the return value type for + * {@link #produceResult}: *

* - * @param the JVM type of result values + * @param the JVM type of result values, MUST be consistent with the {@link DataType} + * returned via {@link #resultType()}, according to the mapping above. */ public interface ScalarFunction extends BoundFunction { String MAGIC_METHOD_NAME = "invoke"; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b3310635cdcb6..757778b66c345 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import java.lang.reflect.Method +import java.lang.reflect.{Method, Modifier} import java.util import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean @@ -2181,6 +2181,9 @@ class Analyzer(override val catalogManager: CatalogManager) // match the input type through `BoundFunction.inputTypes`. val argClasses = inputType.fields.map(_.dataType) findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { + case Some(m) if Modifier.isStatic(m.getModifiers) => + StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), + MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable) case Some(_) => val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), diff --git a/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt b/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..564742f653e57 --- /dev/null +++ b/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt @@ -0,0 +1,44 @@ +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +scalar function (long + long) -> long, result_nullable = true codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------ +native_long_add 17789 18405 580 28.1 35.6 1.0X +java_long_add_default 85058 87073 NaN 5.9 170.1 0.2X +java_long_add_magic 20262 20641 352 24.7 40.5 0.9X +java_long_add_static_magic 19458 19524 105 25.7 38.9 0.9X +scala_long_add_default 85892 86496 560 5.8 171.8 0.2X +scala_long_add_magic 20164 20330 212 24.8 40.3 0.9X + +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +scalar function (long + long) -> long, result_nullable = false codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 18290 18467 157 27.3 36.6 1.0X +java_long_add_default 82415 82687 270 6.1 164.8 0.2X +java_long_add_magic 19941 20032 85 25.1 39.9 0.9X +java_long_add_static_magic 17861 17940 92 28.0 35.7 1.0X +scala_long_add_default 83800 85639 NaN 6.0 167.6 0.2X +scala_long_add_magic 20103 20123 18 24.9 40.2 0.9X + +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +scalar function (long + long) -> long, result_nullable = true codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 46039 46199 162 10.9 92.1 1.0X +java_long_add_default 113199 113773 720 4.4 226.4 0.4X +java_long_add_magic 158252 159419 1075 3.2 316.5 0.3X +java_long_add_static_magic 157162 157676 516 3.2 314.3 0.3X +scala_long_add_default 112363 113264 1503 4.4 224.7 0.4X +scala_long_add_magic 158122 159010 835 3.2 316.2 0.3X + +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +scalar function (long + long) -> long, result_nullable = false codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 42685 42743 54 11.7 85.4 1.0X +java_long_add_default 92041 92236 202 5.4 184.1 0.5X +java_long_add_magic 148299 148722 397 3.4 296.6 0.3X +java_long_add_static_magic 140599 141064 442 3.6 281.2 0.3X +scala_long_add_default 91896 92980 959 5.4 183.8 0.5X +scala_long_add_magic 148031 148802 759 3.4 296.1 0.3X + diff --git a/sql/core/benchmarks/V2FunctionBenchmark-results.txt b/sql/core/benchmarks/V2FunctionBenchmark-results.txt new file mode 100644 index 0000000000000..2035aa3633b80 --- /dev/null +++ b/sql/core/benchmarks/V2FunctionBenchmark-results.txt @@ -0,0 +1,44 @@ +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +scalar function (long + long) -> long, result_nullable = true codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------ +native_long_add 10559 11585 903 47.4 21.1 1.0X +java_long_add_default 78979 80089 987 6.3 158.0 0.1X +java_long_add_magic 14061 14326 305 35.6 28.1 0.8X +java_long_add_static_magic 11971 12150 242 41.8 23.9 0.9X +scala_long_add_default 77254 78565 1254 6.5 154.5 0.1X +scala_long_add_magic 13174 13232 51 38.0 26.3 0.8X + +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +scalar function (long + long) -> long, result_nullable = false codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 10489 10665 162 47.7 21.0 1.0X +java_long_add_default 66636 68422 NaN 7.5 133.3 0.2X +java_long_add_magic 13504 14213 883 37.0 27.0 0.8X +java_long_add_static_magic 11726 11984 240 42.6 23.5 0.9X +scala_long_add_default 75906 76130 196 6.6 151.8 0.1X +scala_long_add_magic 14480 14770 261 34.5 29.0 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +scalar function (long + long) -> long, result_nullable = true codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 39178 39548 323 12.8 78.4 1.0X +java_long_add_default 84756 85509 1092 5.9 169.5 0.5X +java_long_add_magic 199140 200801 1823 2.5 398.3 0.2X 
+java_long_add_static_magic 203500 207050 NaN 2.5 407.0 0.2X +scala_long_add_default 101180 101421 387 4.9 202.4 0.4X +scala_long_add_magic 193277 197006 1138 2.6 386.6 0.2X + +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +scalar function (long + long) -> long, result_nullable = false codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 37064 37333 235 13.5 74.1 1.0X +java_long_add_default 104439 107802 NaN 4.8 208.9 0.4X +java_long_add_magic 212496 214321 NaN 2.4 425.0 0.2X +java_long_add_static_magic 239551 240619 1652 2.1 479.1 0.2X +scala_long_add_default 122413 123171 788 4.1 244.8 0.3X +scala_long_add_magic 215912 222715 NaN 2.3 431.8 0.2X + diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java new file mode 100644 index 0000000000000..e2e7136d6f44c --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaLongAdd implements UnboundFunction { + private final ScalarFunction impl; + + public JavaLongAdd(ScalarFunction impl) { + this.impl = impl; + } + + @Override + public String name() { + return "long_add"; + } + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length != 2) { + throw new UnsupportedOperationException("Expect two arguments"); + } + StructField[] fields = inputType.fields(); + if (!(fields[0].dataType() instanceof LongType)) { + throw new UnsupportedOperationException("Expect first argument to be LongType"); + } + if (!(fields[1].dataType() instanceof LongType)) { + throw new UnsupportedOperationException("Expect second argument to be LongType"); + } + return impl; + } + + @Override + public String description() { + return "long_add"; + } + + private abstract static class JavaLongAddBase implements ScalarFunction { + private final boolean isResultNullable; + + JavaLongAddBase(boolean isResultNullable) { + this.isResultNullable = isResultNullable; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.LongType, DataTypes.LongType }; + } + + @Override + public DataType resultType() { + return DataTypes.LongType; + } + + @Override + public boolean isResultNullable() { + return isResultNullable; + } + } + + public static class JavaLongAddDefault extends JavaLongAddBase { + public JavaLongAddDefault(boolean isResultNullable) { + super(isResultNullable); + } + + @Override + public String name() { + return "long_add_default"; + } + + @Override + public Long produceResult(InternalRow input) { + return input.getLong(0) + input.getLong(1); + } + } + + public static class JavaLongAddMagic extends JavaLongAddBase { + public JavaLongAddMagic(boolean isResultNullable) { + super(isResultNullable); + } + + @Override + public String name() { + return "long_add_magic"; + } + + public long invoke(long left, long right) { + return left + right; + } + } + + public static class JavaLongAddStaticMagic extends JavaLongAddBase { + public JavaLongAddStaticMagic(boolean isResultNullable) { + super(isResultNullable); + } + + @Override + public String name() { + return "long_add_static_magic"; + } + + public static long invoke(long left, long right) { + return left + right; + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java index 8b2d883a3703f..7cd010b9365bb 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java @@ -58,7 +58,7 @@ public String description() { " strlen(string) -> int"; } - public static class JavaStrLenDefault implements ScalarFunction { + private abstract static class JavaStrLenBase implements ScalarFunction { @Override public 
DataType[] inputTypes() { return new DataType[] { DataTypes.StringType }; @@ -73,7 +73,9 @@ public DataType resultType() { public String name() { return "strlen"; } + } + public static class JavaStrLenDefault extends JavaStrLenBase { @Override public Integer produceResult(InternalRow input) { String str = input.getString(0); @@ -81,42 +83,42 @@ public Integer produceResult(InternalRow input) { } } - public static class JavaStrLenMagic implements ScalarFunction { - @Override - public DataType[] inputTypes() { - return new DataType[] { DataTypes.StringType }; + public static class JavaStrLenMagic extends JavaStrLenBase { + public int invoke(UTF8String str) { + return str.toString().length(); } + } - @Override - public DataType resultType() { - return DataTypes.IntegerType; + public static class JavaStrLenStaticMagic extends JavaStrLenBase { + public static int invoke(UTF8String str) { + return str.toString().length(); } + } + public static class JavaStrLenBoth extends JavaStrLenBase { @Override - public String name() { - return "strlen"; + public Integer produceResult(InternalRow input) { + String str = input.getString(0); + return str.length(); } - public int invoke(UTF8String str) { - return str.toString().length(); + return str.toString().length() + 100; } } - public static class JavaStrLenNoImpl implements ScalarFunction { - @Override - public DataType[] inputTypes() { - return new DataType[] { DataTypes.StringType }; + // even though the static magic method is present, it has incorrect parameter type and so Spark + // should fallback to the non-static magic method + public static class JavaStrLenBadStaticMagic extends JavaStrLenBase { + public static int invoke(String str) { + return str.length(); } - @Override - public DataType resultType() { - return DataTypes.IntegerType; + public int invoke(UTF8String str) { + return str.toString().length() + 100; } + } - @Override - public String name() { - return "strlen"; - } + public static class JavaStrLenNoImpl extends JavaStrLenBase { } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index fe856ffecb84a..b269da39daf38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -183,6 +183,28 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) } + test("scalar function: static magic method in Java") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenStaticMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: magic method should take higher precedence in Java") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenBoth)) + // to differentiate, the static method returns string length + 100 + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(103) :: Nil) + } + + test("scalar function: bad static magic method should fallback to non-static") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), 
+ new JavaStrLen(new JavaStrLenBadStaticMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(103) :: Nil) + } + test("scalar function: no implementation found in Java") { catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) addFunction(Identifier.of(Array("ns"), "strlen"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala new file mode 100644 index 0000000000000..9328a9a8e93e3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.functions + +import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd +import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddStaticMagic} + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression} +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType, NumericType, StructType} + +/** + * Benchmark to measure DataSourceV2 UDF performance + * {{{ + * To run this benchmark: + * 1. without sbt: + * bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/V2FunctionBenchmark-results.txt". 
+ * }}} + * '''NOTE''': to update the result of this benchmark, please use Github benchmark action: + * https://spark.apache.org/developer-tools.html#github-workflow-benchmarks + */ +object V2FunctionBenchmark extends SqlBasedBenchmark { + val catalogName: String = "benchmark_catalog" + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val N = 500L * 1000 * 1000 + Seq(true, false).foreach { codegenEnabled => + Seq(true, false).foreach { resultNullable => + scalarFunctionBenchmark(N, codegenEnabled = codegenEnabled, + resultNullable = resultNullable) + } + } + } + + private def scalarFunctionBenchmark( + N: Long, + codegenEnabled: Boolean, + resultNullable: Boolean): Unit = { + withSQLConf(s"spark.sql.catalog.$catalogName" -> classOf[InMemoryCatalog].getName) { + createFunction("java_long_add_default", + new JavaLongAdd(new JavaLongAddDefault(resultNullable))) + createFunction("java_long_add_magic", new JavaLongAdd(new JavaLongAddMagic(resultNullable))) + createFunction("java_long_add_static_magic", + new JavaLongAdd(new JavaLongAddStaticMagic(resultNullable))) + createFunction("scala_long_add_default", + LongAddUnbound(new LongAddWithProduceResult(resultNullable))) + createFunction("scala_long_add_magic", LongAddUnbound(new LongAddWithMagic(resultNullable))) + + val codeGenFactoryMode = if (codegenEnabled) FALLBACK else NO_CODEGEN + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString, + SQLConf.CODEGEN_FACTORY_MODE.key -> codeGenFactoryMode.toString) { + val name = s"scalar function (long + long) -> long, result_nullable = $resultNullable " + + s"codegen = $codegenEnabled" + val benchmark = new Benchmark(name, N, output = output) + benchmark.addCase(s"native_long_add", numIters = 3) { _ => + spark.range(N).select(Column(NativeAdd($"id".expr, $"id".expr, resultNullable))).noop() + } + Seq("java_long_add_default", "java_long_add_magic", "java_long_add_static_magic", + "scala_long_add_default", "scala_long_add_magic").foreach { functionName => + benchmark.addCase(s"$functionName", numIters = 3) { _ => + spark.range(N).selectExpr(s"$catalogName.$functionName(id, id)").noop() + } + } + benchmark.run() + } + } + } + + private def createFunction(name: String, fn: UnboundFunction): Unit = { + val catalog = spark.sessionState.catalogManager.catalog(catalogName) + val ident = Identifier.of(Array.empty, name) + catalog.asInstanceOf[InMemoryCatalog].createFunction(ident, fn) + } + + case class NativeAdd( + left: Expression, + right: Expression, + override val nullable: Boolean) extends BinaryArithmetic { + override protected val failOnError: Boolean = true + override def inputType: AbstractDataType = NumericType + override def symbol: String = "+" + override def exactMathMethod: Option[String] = Some("addExact") + + private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError) + protected override def nullSafeEval(input1: Any, input2: Any): Any = + numeric.plus(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): NativeAdd = copy(left = newLeft, right = newRight) + } + + case class LongAddUnbound(impl: ScalarFunction[Long]) extends UnboundFunction { + override def bind(inputType: StructType): BoundFunction = impl + override def description(): String = name() + override def name(): String = "long_add_unbound" + } + + abstract class LongAddBase(resultNullable: Boolean) extends ScalarFunction[Long] { + override def inputTypes(): Array[DataType] = Array(LongType, LongType) + override def 
resultType(): DataType = LongType + override def isResultNullable: Boolean = resultNullable + } + + class LongAddWithProduceResult(resultNullable: Boolean) extends LongAddBase(resultNullable) { + override def produceResult(input: InternalRow): Long = { + input.getLong(0) + input.getLong(1) + } + override def name(): String = "long_add_default" + } + + class LongAddWithMagic(resultNullable: Boolean) extends LongAddBase(resultNullable) { + def invoke(left: Long, right: Long): Long = { + left + right + } + override def name(): String = "long_add_magic" + } +} +
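As a closing illustration, here is a hedged end-to-end usage sketch modeled on the benchmark and test code above. It assumes the test-only JavaLongAdd and InMemoryCatalog classes are on the classpath, uses a hypothetical catalog name "testcat", and is placed in the same package as the benchmark so the internal sessionState/catalogManager accessors are visible.

```scala
// Hedged sketch; not part of the patch. Placed in the benchmark's package so that
// the internal SparkSession.sessionState accessor is accessible.
package org.apache.spark.sql.connector.functions

import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd
import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddStaticMagic

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog}

object StaticMagicMethodExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // Register a v2 catalog backed by the test-only InMemoryCatalog.
      .config("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
      .getOrCreate()

    // Register a function whose bound implementation exposes a static magic method;
    // the analyzer resolves it and compiles the call to StaticInvoke.
    val catalog = spark.sessionState.catalogManager.catalog("testcat")
      .asInstanceOf[InMemoryCatalog]
    catalog.createFunction(
      Identifier.of(Array.empty[String], "long_add"),
      new JavaLongAdd(new JavaLongAddStaticMagic(false)))

    spark.sql("SELECT testcat.long_add(1L, 2L)").show()  // expected result: 3

    spark.stop()
  }
}
```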