diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
index ef755aae3fb07..858ab923490fc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
@@ -29,33 +29,62 @@
*
* The JVM type of result values produced by this function must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
+ * The mapping between {@link DataType} and the corresponding JVM type is defined below.
*
* IMPORTANT: the default implementation of {@link #produceResult} throws
- * {@link UnsupportedOperationException}. Users can choose to override this method, or implement
- * a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
- * instead of a {@link InternalRow}. The magic method will be loaded by Spark through Java
- * reflection and will also provide better performance in general, due to optimizations such as
- * codegen, removal of Java boxing, etc.
- *
+ * {@link UnsupportedOperationException}. Users must choose to either override this method, or
+ * implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters
+ * instead of a {@link InternalRow}. The magic method approach is generally recommended because it
+ * provides better performance over the default {@link #produceResult}, due to optimizations such
+ * as whole-stage codegen, elimination of Java boxing, etc.
+ *
+ * In addition, for stateless Java functions, users can optionally define the
+ * {@link #MAGIC_METHOD_NAME} as a static method, which further avoids certain runtime costs such
+ * as Java dynamic dispatch.
+ *
* For example, a scalar UDF for adding two integers can be defined as follow with the magic
* method approach:
*
*
*   public class IntegerAdd implements{@code ScalarFunction<Integer>} {
+ * public DataType[] inputTypes() {
+ * return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
+ * }
* public int invoke(int left, int right) {
* return left + right;
* }
* }
*
- * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
- * {@link #produceResult} to evalaute the inputs. In general Spark looks up the magic method by
- * first converting the actual input SQL data types to their corresponding Java types following
- * the mapping defined below, and then checking if there is a matching method from all the
- * declared methods in the UDF class, using method name (i.e., {@link #MAGIC_METHOD_NAME}) and
- * the Java types. If no magic method is found, Spark will falls back to use {@link #produceResult}.
+ * In the above example, since {@link #MAGIC_METHOD_NAME} is defined with matching parameter
+ * types and a matching return type, Spark will use it to evaluate inputs.
+ *
+ * As another example, in the following:
+ *
+ *   public class IntegerAdd implements{@code ScalarFunction<Integer>} {
+ * public DataType[] inputTypes() {
+ * return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
+ * }
+ * public static int invoke(int left, int right) {
+ * return left + right;
+ * }
+ * public Integer produceResult(InternalRow input) {
+ * return input.getInt(0) + input.getInt(1);
+ * }
+ * }
+ *
+ *
+ * the class defines both the magic method and the {@link #produceResult}, and Spark will use
+ * {@link #MAGIC_METHOD_NAME} over the {@link #produceResult(InternalRow)} as it takes higher
+ * precedence. Also note that the magic method is declared as a static method in this case.
+ *
+ * Resolution of the magic method is done during query analysis, where Spark looks up the magic
+ * method by first converting the actual input SQL data types to their corresponding Java types
+ * following the mapping defined below, and then checking if there is a matching method from all the
+ * declared methods in the UDF class, using method name and the Java types.
*
- * The following are the mapping from {@link DataType SQL data type} to Java type through
- * the magic method approach:
+ * The following is the mapping from {@link DataType SQL data type} to Java type which is used
+ * by Spark to infer parameter types for the magic methods as well as return value type for
+ * {@link #produceResult}:
*
* - {@link org.apache.spark.sql.types.BooleanType}: {@code boolean}
* - {@link org.apache.spark.sql.types.ByteType}: {@code byte}
@@ -80,7 +109,8 @@
* {@link org.apache.spark.sql.catalyst.util.MapData}
*
*
- * @param <R> the JVM type of result values
+ * @param <R> the JVM type of result values, MUST be consistent with the {@link DataType}
+ *          returned via {@link #resultType()}, according to the mapping above.
*/
public interface ScalarFunction<R> extends BoundFunction {
String MAGIC_METHOD_NAME = "invoke";
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b3310635cdcb6..757778b66c345 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
-import java.lang.reflect.Method
+import java.lang.reflect.{Method, Modifier}
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@@ -2181,6 +2181,9 @@ class Analyzer(override val catalogManager: CatalogManager)
// match the input type through `BoundFunction.inputTypes`.
val argClasses = inputType.fields.map(_.dataType)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
+ case Some(m) if Modifier.isStatic(m.getModifiers) =>
+ StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
+ MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable)
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
diff --git a/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt b/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..564742f653e57
--- /dev/null
+++ b/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt
@@ -0,0 +1,44 @@
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+scalar function (long + long) -> long, result_nullable = true codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------------------------------------------
+native_long_add 17789 18405 580 28.1 35.6 1.0X
+java_long_add_default 85058 87073 NaN 5.9 170.1 0.2X
+java_long_add_magic 20262 20641 352 24.7 40.5 0.9X
+java_long_add_static_magic 19458 19524 105 25.7 38.9 0.9X
+scala_long_add_default 85892 86496 560 5.8 171.8 0.2X
+scala_long_add_magic 20164 20330 212 24.8 40.3 0.9X
+
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+scalar function (long + long) -> long, result_nullable = false codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+-------------------------------------------------------------------------------------------------------------------------------------------------------------
+native_long_add 18290 18467 157 27.3 36.6 1.0X
+java_long_add_default 82415 82687 270 6.1 164.8 0.2X
+java_long_add_magic 19941 20032 85 25.1 39.9 0.9X
+java_long_add_static_magic 17861 17940 92 28.0 35.7 1.0X
+scala_long_add_default 83800 85639 NaN 6.0 167.6 0.2X
+scala_long_add_magic 20103 20123 18 24.9 40.2 0.9X
+
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+scalar function (long + long) -> long, result_nullable = true codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+-------------------------------------------------------------------------------------------------------------------------------------------------------------
+native_long_add 46039 46199 162 10.9 92.1 1.0X
+java_long_add_default 113199 113773 720 4.4 226.4 0.4X
+java_long_add_magic 158252 159419 1075 3.2 316.5 0.3X
+java_long_add_static_magic 157162 157676 516 3.2 314.3 0.3X
+scala_long_add_default 112363 113264 1503 4.4 224.7 0.4X
+scala_long_add_magic 158122 159010 835 3.2 316.2 0.3X
+
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+scalar function (long + long) -> long, result_nullable = false codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+--------------------------------------------------------------------------------------------------------------------------------------------------------------
+native_long_add 42685 42743 54 11.7 85.4 1.0X
+java_long_add_default 92041 92236 202 5.4 184.1 0.5X
+java_long_add_magic 148299 148722 397 3.4 296.6 0.3X
+java_long_add_static_magic 140599 141064 442 3.6 281.2 0.3X
+scala_long_add_default 91896 92980 959 5.4 183.8 0.5X
+scala_long_add_magic 148031 148802 759 3.4 296.1 0.3X
+
diff --git a/sql/core/benchmarks/V2FunctionBenchmark-results.txt b/sql/core/benchmarks/V2FunctionBenchmark-results.txt
new file mode 100644
index 0000000000000..2035aa3633b80
--- /dev/null
+++ b/sql/core/benchmarks/V2FunctionBenchmark-results.txt
@@ -0,0 +1,44 @@
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
+scalar function (long + long) -> long, result_nullable = true codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------------------------------------------
+native_long_add 10559 11585 903 47.4 21.1 1.0X
+java_long_add_default 78979 80089 987 6.3 158.0 0.1X
+java_long_add_magic 14061 14326 305 35.6 28.1 0.8X
+java_long_add_static_magic 11971 12150 242 41.8 23.9 0.9X
+scala_long_add_default 77254 78565 1254 6.5 154.5 0.1X
+scala_long_add_magic 13174 13232 51 38.0 26.3 0.8X
+
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
+scalar function (long + long) -> long, result_nullable = false codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+-------------------------------------------------------------------------------------------------------------------------------------------------------------
+native_long_add 10489 10665 162 47.7 21.0 1.0X
+java_long_add_default 66636 68422 NaN 7.5 133.3 0.2X
+java_long_add_magic 13504 14213 883 37.0 27.0 0.8X
+java_long_add_static_magic 11726 11984 240 42.6 23.5 0.9X
+scala_long_add_default 75906 76130 196 6.6 151.8 0.1X
+scala_long_add_magic 14480 14770 261 34.5 29.0 0.7X
+
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
+scalar function (long + long) -> long, result_nullable = true codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+-------------------------------------------------------------------------------------------------------------------------------------------------------------
+native_long_add 39178 39548 323 12.8 78.4 1.0X
+java_long_add_default 84756 85509 1092 5.9 169.5 0.5X
+java_long_add_magic 199140 200801 1823 2.5 398.3 0.2X
+java_long_add_static_magic 203500 207050 NaN 2.5 407.0 0.2X
+scala_long_add_default 101180 101421 387 4.9 202.4 0.4X
+scala_long_add_magic 193277 197006 1138 2.6 386.6 0.2X
+
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
+scalar function (long + long) -> long, result_nullable = false codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+--------------------------------------------------------------------------------------------------------------------------------------------------------------
+native_long_add 37064 37333 235 13.5 74.1 1.0X
+java_long_add_default 104439 107802 NaN 4.8 208.9 0.4X
+java_long_add_magic 212496 214321 NaN 2.4 425.0 0.2X
+java_long_add_static_magic 239551 240619 1652 2.1 479.1 0.2X
+scala_long_add_default 122413 123171 788 4.1 244.8 0.3X
+scala_long_add_magic 215912 222715 NaN 2.3 431.8 0.2X
+
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
new file mode 100644
index 0000000000000..e2e7136d6f44c
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaLongAdd implements UnboundFunction {
+  private final ScalarFunction<?> impl;
+
+  public JavaLongAdd(ScalarFunction<?> impl) {
+ this.impl = impl;
+ }
+
+ @Override
+ public String name() {
+ return "long_add";
+ }
+
+ @Override
+ public BoundFunction bind(StructType inputType) {
+ if (inputType.fields().length != 2) {
+ throw new UnsupportedOperationException("Expect two arguments");
+ }
+ StructField[] fields = inputType.fields();
+ if (!(fields[0].dataType() instanceof LongType)) {
+ throw new UnsupportedOperationException("Expect first argument to be LongType");
+ }
+ if (!(fields[1].dataType() instanceof LongType)) {
+ throw new UnsupportedOperationException("Expect second argument to be LongType");
+ }
+ return impl;
+ }
+
+ @Override
+ public String description() {
+ return "long_add";
+ }
+
+  private abstract static class JavaLongAddBase implements ScalarFunction<Long> {
+ private final boolean isResultNullable;
+
+ JavaLongAddBase(boolean isResultNullable) {
+ this.isResultNullable = isResultNullable;
+ }
+
+ @Override
+ public DataType[] inputTypes() {
+ return new DataType[] { DataTypes.LongType, DataTypes.LongType };
+ }
+
+ @Override
+ public DataType resultType() {
+ return DataTypes.LongType;
+ }
+
+ @Override
+ public boolean isResultNullable() {
+ return isResultNullable;
+ }
+ }
+
+ public static class JavaLongAddDefault extends JavaLongAddBase {
+ public JavaLongAddDefault(boolean isResultNullable) {
+ super(isResultNullable);
+ }
+
+ @Override
+ public String name() {
+ return "long_add_default";
+ }
+
+ @Override
+ public Long produceResult(InternalRow input) {
+ return input.getLong(0) + input.getLong(1);
+ }
+ }
+
+ public static class JavaLongAddMagic extends JavaLongAddBase {
+ public JavaLongAddMagic(boolean isResultNullable) {
+ super(isResultNullable);
+ }
+
+ @Override
+ public String name() {
+ return "long_add_magic";
+ }
+
+ public long invoke(long left, long right) {
+ return left + right;
+ }
+ }
+
+ public static class JavaLongAddStaticMagic extends JavaLongAddBase {
+ public JavaLongAddStaticMagic(boolean isResultNullable) {
+ super(isResultNullable);
+ }
+
+ @Override
+ public String name() {
+ return "long_add_static_magic";
+ }
+
+ public static long invoke(long left, long right) {
+ return left + right;
+ }
+ }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
index 8b2d883a3703f..7cd010b9365bb 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
@@ -58,7 +58,7 @@ public String description() {
" strlen(string) -> int";
}
-  public static class JavaStrLenDefault implements ScalarFunction<Integer> {
+  private abstract static class JavaStrLenBase implements ScalarFunction<Integer> {
@Override
public DataType[] inputTypes() {
return new DataType[] { DataTypes.StringType };
@@ -73,7 +73,9 @@ public DataType resultType() {
public String name() {
return "strlen";
}
+ }
+ public static class JavaStrLenDefault extends JavaStrLenBase {
@Override
public Integer produceResult(InternalRow input) {
String str = input.getString(0);
@@ -81,42 +83,42 @@ public Integer produceResult(InternalRow input) {
}
}
-  public static class JavaStrLenMagic implements ScalarFunction<Integer> {
- @Override
- public DataType[] inputTypes() {
- return new DataType[] { DataTypes.StringType };
+ public static class JavaStrLenMagic extends JavaStrLenBase {
+ public int invoke(UTF8String str) {
+ return str.toString().length();
}
+ }
- @Override
- public DataType resultType() {
- return DataTypes.IntegerType;
+ public static class JavaStrLenStaticMagic extends JavaStrLenBase {
+ public static int invoke(UTF8String str) {
+ return str.toString().length();
}
+ }
+ public static class JavaStrLenBoth extends JavaStrLenBase {
@Override
- public String name() {
- return "strlen";
+ public Integer produceResult(InternalRow input) {
+ String str = input.getString(0);
+ return str.length();
}
-
public int invoke(UTF8String str) {
- return str.toString().length();
+ return str.toString().length() + 100;
}
}
-  public static class JavaStrLenNoImpl implements ScalarFunction<Integer> {
- @Override
- public DataType[] inputTypes() {
- return new DataType[] { DataTypes.StringType };
+ // even though the static magic method is present, it has incorrect parameter type and so Spark
+  // should fall back to the non-static magic method
+ public static class JavaStrLenBadStaticMagic extends JavaStrLenBase {
+ public static int invoke(String str) {
+ return str.length();
}
- @Override
- public DataType resultType() {
- return DataTypes.IntegerType;
+ public int invoke(UTF8String str) {
+ return str.toString().length() + 100;
}
+ }
- @Override
- public String name() {
- return "strlen";
- }
+ public static class JavaStrLenNoImpl extends JavaStrLenBase {
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index fe856ffecb84a..b269da39daf38 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -183,6 +183,28 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
}
+ test("scalar function: static magic method in Java") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"),
+ new JavaStrLen(new JavaStrLenStaticMagic))
+ checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
+ }
+
+ test("scalar function: magic method should take higher precedence in Java") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"),
+ new JavaStrLen(new JavaStrLenBoth))
+    // to differentiate, the magic method returns string length + 100
+ checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(103) :: Nil)
+ }
+
+ test("scalar function: bad static magic method should fallback to non-static") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"),
+ new JavaStrLen(new JavaStrLenBadStaticMagic))
+ checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(103) :: Nil)
+ }
+
test("scalar function: no implementation found in Java") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
new file mode 100644
index 0000000000000..9328a9a8e93e3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.functions
+
+import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd
+import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddStaticMagic}
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog}
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
+import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType, NumericType, StructType}
+
+/**
+ * Benchmark to measure DataSourceV2 UDF performance
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt:
+ *      bin/spark-submit --class <this class>
+ *        --jars <spark core test jar>,<spark catalyst test jar> <sql core test jar>
+ *   2. build/sbt "sql/test:runMain <this class>"
+ *   3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ * Results will be written to "benchmarks/V2FunctionBenchmark-results.txt".
+ * }}}
+ * '''NOTE''': to update the result of this benchmark, please use Github benchmark action:
+ * https://spark.apache.org/developer-tools.html#github-workflow-benchmarks
+ */
+object V2FunctionBenchmark extends SqlBasedBenchmark {
+ val catalogName: String = "benchmark_catalog"
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ val N = 500L * 1000 * 1000
+ Seq(true, false).foreach { codegenEnabled =>
+ Seq(true, false).foreach { resultNullable =>
+ scalarFunctionBenchmark(N, codegenEnabled = codegenEnabled,
+ resultNullable = resultNullable)
+ }
+ }
+ }
+
+ private def scalarFunctionBenchmark(
+ N: Long,
+ codegenEnabled: Boolean,
+ resultNullable: Boolean): Unit = {
+ withSQLConf(s"spark.sql.catalog.$catalogName" -> classOf[InMemoryCatalog].getName) {
+ createFunction("java_long_add_default",
+ new JavaLongAdd(new JavaLongAddDefault(resultNullable)))
+ createFunction("java_long_add_magic", new JavaLongAdd(new JavaLongAddMagic(resultNullable)))
+ createFunction("java_long_add_static_magic",
+ new JavaLongAdd(new JavaLongAddStaticMagic(resultNullable)))
+ createFunction("scala_long_add_default",
+ LongAddUnbound(new LongAddWithProduceResult(resultNullable)))
+ createFunction("scala_long_add_magic", LongAddUnbound(new LongAddWithMagic(resultNullable)))
+
+ val codeGenFactoryMode = if (codegenEnabled) FALLBACK else NO_CODEGEN
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString,
+ SQLConf.CODEGEN_FACTORY_MODE.key -> codeGenFactoryMode.toString) {
+ val name = s"scalar function (long + long) -> long, result_nullable = $resultNullable " +
+ s"codegen = $codegenEnabled"
+ val benchmark = new Benchmark(name, N, output = output)
+ benchmark.addCase(s"native_long_add", numIters = 3) { _ =>
+ spark.range(N).select(Column(NativeAdd($"id".expr, $"id".expr, resultNullable))).noop()
+ }
+ Seq("java_long_add_default", "java_long_add_magic", "java_long_add_static_magic",
+ "scala_long_add_default", "scala_long_add_magic").foreach { functionName =>
+ benchmark.addCase(s"$functionName", numIters = 3) { _ =>
+ spark.range(N).selectExpr(s"$catalogName.$functionName(id, id)").noop()
+ }
+ }
+ benchmark.run()
+ }
+ }
+ }
+
+ private def createFunction(name: String, fn: UnboundFunction): Unit = {
+ val catalog = spark.sessionState.catalogManager.catalog(catalogName)
+ val ident = Identifier.of(Array.empty, name)
+ catalog.asInstanceOf[InMemoryCatalog].createFunction(ident, fn)
+ }
+
+ case class NativeAdd(
+ left: Expression,
+ right: Expression,
+ override val nullable: Boolean) extends BinaryArithmetic {
+ override protected val failOnError: Boolean = true
+ override def inputType: AbstractDataType = NumericType
+ override def symbol: String = "+"
+ override def exactMathMethod: Option[String] = Some("addExact")
+
+ private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any =
+ numeric.plus(input1, input2)
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression,
+ newRight: Expression): NativeAdd = copy(left = newLeft, right = newRight)
+ }
+
+ case class LongAddUnbound(impl: ScalarFunction[Long]) extends UnboundFunction {
+ override def bind(inputType: StructType): BoundFunction = impl
+ override def description(): String = name()
+ override def name(): String = "long_add_unbound"
+ }
+
+ abstract class LongAddBase(resultNullable: Boolean) extends ScalarFunction[Long] {
+ override def inputTypes(): Array[DataType] = Array(LongType, LongType)
+ override def resultType(): DataType = LongType
+ override def isResultNullable: Boolean = resultNullable
+ }
+
+ class LongAddWithProduceResult(resultNullable: Boolean) extends LongAddBase(resultNullable) {
+ override def produceResult(input: InternalRow): Long = {
+ input.getLong(0) + input.getLong(1)
+ }
+ override def name(): String = "long_add_default"
+ }
+
+ class LongAddWithMagic(resultNullable: Boolean) extends LongAddBase(resultNullable) {
+ def invoke(left: Long, right: Long): Long = {
+ left + right
+ }
+ override def name(): String = "long_add_magic"
+ }
+}
+