From f47e0f83794fc9beee3c07dca4c0bb7e0eab81e4 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Fri, 7 May 2021 20:34:51 -0700 Subject: [PATCH] [SPARK-35261][SQL] Support static magic method for stateless Java ScalarFunction ### What changes were proposed in this pull request? This allows `ScalarFunction` implemented in Java to optionally specify the magic method `invoke` to be static, which can be used if the UDF is stateless. Comparing to the non-static method, it can potentially give better performance due to elimination of dynamic dispatch, etc. Also added a benchmark to measure performance of: the default `produceResult`, non-static magic method and static magic method. ### Why are the changes needed? For UDFs that are stateless (e.g., no need to maintain intermediate state between each function call), it's better to allow users to implement the UDF function as static method which could potentially give better performance. ### Does this PR introduce _any_ user-facing change? Yes. Spark users can now have the choice to define static magic method for `ScalarFunction` when it is written in Java and when the UDF is stateless. ### How was this patch tested? Added new UT. Closes #32407 from sunchao/SPARK-35261. 
Authored-by: Chao Sun Signed-off-by: Dongjoon Hyun --- .../catalog/functions/ScalarFunction.java | 60 +++++-- .../sql/catalyst/analysis/Analyzer.scala | 5 +- .../V2FunctionBenchmark-jdk11-results.txt | 44 ++++++ .../V2FunctionBenchmark-results.txt | 44 ++++++ .../catalog/functions/JavaLongAdd.java | 130 ++++++++++++++++ .../catalog/functions/JavaStrLen.java | 48 +++--- .../connector/DataSourceV2FunctionSuite.scala | 22 +++ .../functions/V2FunctionBenchmark.scala | 147 ++++++++++++++++++ 8 files changed, 461 insertions(+), 39 deletions(-) create mode 100644 sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt create mode 100644 sql/core/benchmarks/V2FunctionBenchmark-results.txt create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index ef755aae3fb0..858ab923490f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -29,33 +29,62 @@ *

* The JVM type of result values produced by this function must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. + * The mapping between {@link DataType} and the corresponding JVM type is defined below. *

* IMPORTANT: the default implementation of {@link #produceResult} throws - * {@link UnsupportedOperationException}. Users can choose to override this method, or implement - * a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters - * instead of a {@link InternalRow}. The magic method will be loaded by Spark through Java - * reflection and will also provide better performance in general, due to optimizations such as - * codegen, removal of Java boxing, etc. - * + * {@link UnsupportedOperationException}. Users must choose to either override this method, or + * implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters + * instead of a {@link InternalRow}. The magic method approach is generally recommended because it + * provides better performance over the default {@link #produceResult}, due to optimizations such + * as whole-stage codegen, elimination of Java boxing, etc. + *

+ * In addition, for stateless Java functions, users can optionally define the + * {@link #MAGIC_METHOD_NAME} as a static method, which further avoids certain runtime costs such + * as Java dynamic dispatch. + *

* For example, a scalar UDF for adding two integers can be defined as follow with the magic * method approach: * *

  *   public class IntegerAdd implements{@code ScalarFunction} {
+ *     public DataType[] inputTypes() {
+ *       return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
+ *     }
  *     public int invoke(int left, int right) {
  *       return left + right;
  *     }
  *   }
  * 
- * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over - * {@link #produceResult} to evalaute the inputs. In general Spark looks up the magic method by - * first converting the actual input SQL data types to their corresponding Java types following - * the mapping defined below, and then checking if there is a matching method from all the - * declared methods in the UDF class, using method name (i.e., {@link #MAGIC_METHOD_NAME}) and - * the Java types. If no magic method is found, Spark will falls back to use {@link #produceResult}. + * In the above, since {@link #MAGIC_METHOD_NAME} is defined with + * matching parameter types and return type, Spark will use it to evaluate the inputs.

+ * As another example, in the following: + *

+ *   public class IntegerAdd implements {@code ScalarFunction<Integer>} {
+ *     public DataType[] inputTypes() {
+ *       return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
+ *     }
+ *     public static int invoke(int left, int right) {
+ *       return left + right;
+ *     }
+ *     public Integer produceResult(InternalRow input) {
+ *       return input.getInt(0) + input.getInt(1);
+ *     }
+ *   }
+ * 
+ * + * the class defines both the magic method and the {@link #produceResult}, and Spark will use + * {@link #MAGIC_METHOD_NAME} over the {@link #produceResult(InternalRow)} as it takes higher + * precedence. Also note that the magic method is declared as a static method in this case. + *

+ * Resolution of the magic method is done during query analysis, where Spark looks up the magic + * method by first converting the actual input SQL data types to their corresponding Java types + * following the mapping defined below, and then checking if there is a matching method from all the + * declared methods in the UDF class, using method name and the Java types.

- * The following are the mapping from {@link DataType SQL data type} to Java type through - * the magic method approach: + * The following is the mapping from {@link DataType SQL data type} to Java type which is used + * by Spark to infer parameter types for the magic methods as well as return value type for + * {@link #produceResult}:

    *
  • {@link org.apache.spark.sql.types.BooleanType}: {@code boolean}
  • *
  • {@link org.apache.spark.sql.types.ByteType}: {@code byte}
  • @@ -80,7 +109,8 @@ * {@link org.apache.spark.sql.catalyst.util.MapData} *
* - * @param the JVM type of result values + * @param the JVM type of result values, MUST be consistent with the {@link DataType} + * returned via {@link #resultType()}, according to the mapping above. */ public interface ScalarFunction extends BoundFunction { String MAGIC_METHOD_NAME = "invoke"; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b3310635cdcb..757778b66c34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import java.lang.reflect.Method +import java.lang.reflect.{Method, Modifier} import java.util import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean @@ -2181,6 +2181,9 @@ class Analyzer(override val catalogManager: CatalogManager) // match the input type through `BoundFunction.inputTypes`. 
val argClasses = inputType.fields.map(_.dataType) findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { + case Some(m) if Modifier.isStatic(m.getModifiers) => + StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), + MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable) case Some(_) => val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), diff --git a/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt b/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt new file mode 100644 index 000000000000..564742f653e5 --- /dev/null +++ b/sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt @@ -0,0 +1,44 @@ +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +scalar function (long + long) -> long, result_nullable = true codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------ +native_long_add 17789 18405 580 28.1 35.6 1.0X +java_long_add_default 85058 87073 NaN 5.9 170.1 0.2X +java_long_add_magic 20262 20641 352 24.7 40.5 0.9X +java_long_add_static_magic 19458 19524 105 25.7 38.9 0.9X +scala_long_add_default 85892 86496 560 5.8 171.8 0.2X +scala_long_add_magic 20164 20330 212 24.8 40.3 0.9X + +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +scalar function (long + long) -> long, result_nullable = false codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 18290 18467 157 27.3 36.6 1.0X +java_long_add_default 82415 82687 270 6.1 164.8 0.2X 
+java_long_add_magic 19941 20032 85 25.1 39.9 0.9X +java_long_add_static_magic 17861 17940 92 28.0 35.7 1.0X +scala_long_add_default 83800 85639 NaN 6.0 167.6 0.2X +scala_long_add_magic 20103 20123 18 24.9 40.2 0.9X + +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +scalar function (long + long) -> long, result_nullable = true codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 46039 46199 162 10.9 92.1 1.0X +java_long_add_default 113199 113773 720 4.4 226.4 0.4X +java_long_add_magic 158252 159419 1075 3.2 316.5 0.3X +java_long_add_static_magic 157162 157676 516 3.2 314.3 0.3X +scala_long_add_default 112363 113264 1503 4.4 224.7 0.4X +scala_long_add_magic 158122 159010 835 3.2 316.2 0.3X + +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +scalar function (long + long) -> long, result_nullable = false codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 42685 42743 54 11.7 85.4 1.0X +java_long_add_default 92041 92236 202 5.4 184.1 0.5X +java_long_add_magic 148299 148722 397 3.4 296.6 0.3X +java_long_add_static_magic 140599 141064 442 3.6 281.2 0.3X +scala_long_add_default 91896 92980 959 5.4 183.8 0.5X +scala_long_add_magic 148031 148802 759 3.4 296.1 0.3X + diff --git a/sql/core/benchmarks/V2FunctionBenchmark-results.txt b/sql/core/benchmarks/V2FunctionBenchmark-results.txt new file mode 100644 index 000000000000..2035aa3633b8 --- /dev/null +++ b/sql/core/benchmarks/V2FunctionBenchmark-results.txt @@ -0,0 +1,44 @@ +OpenJDK 64-Bit Server 
VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +scalar function (long + long) -> long, result_nullable = true codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------ +native_long_add 10559 11585 903 47.4 21.1 1.0X +java_long_add_default 78979 80089 987 6.3 158.0 0.1X +java_long_add_magic 14061 14326 305 35.6 28.1 0.8X +java_long_add_static_magic 11971 12150 242 41.8 23.9 0.9X +scala_long_add_default 77254 78565 1254 6.5 154.5 0.1X +scala_long_add_magic 13174 13232 51 38.0 26.3 0.8X + +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +scalar function (long + long) -> long, result_nullable = false codegen = true: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 10489 10665 162 47.7 21.0 1.0X +java_long_add_default 66636 68422 NaN 7.5 133.3 0.2X +java_long_add_magic 13504 14213 883 37.0 27.0 0.8X +java_long_add_static_magic 11726 11984 240 42.6 23.5 0.9X +scala_long_add_default 75906 76130 196 6.6 151.8 0.1X +scala_long_add_magic 14480 14770 261 34.5 29.0 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +scalar function (long + long) -> long, result_nullable = true codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 39178 39548 323 12.8 78.4 1.0X +java_long_add_default 84756 85509 1092 5.9 169.5 0.5X +java_long_add_magic 199140 
200801 1823 2.5 398.3 0.2X +java_long_add_static_magic 203500 207050 NaN 2.5 407.0 0.2X +scala_long_add_default 101180 101421 387 4.9 202.4 0.4X +scala_long_add_magic 193277 197006 1138 2.6 386.6 0.2X + +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz +scalar function (long + long) -> long, result_nullable = false codegen = false: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------------------------------------- +native_long_add 37064 37333 235 13.5 74.1 1.0X +java_long_add_default 104439 107802 NaN 4.8 208.9 0.4X +java_long_add_magic 212496 214321 NaN 2.4 425.0 0.2X +java_long_add_static_magic 239551 240619 1652 2.1 479.1 0.2X +scala_long_add_default 122413 123171 788 4.1 244.8 0.3X +scala_long_add_magic 215912 222715 NaN 2.3 431.8 0.2X + diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java new file mode 100644 index 000000000000..e2e7136d6f44 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaLongAdd implements UnboundFunction { + private final ScalarFunction impl; + + public JavaLongAdd(ScalarFunction impl) { + this.impl = impl; + } + + @Override + public String name() { + return "long_add"; + } + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length != 2) { + throw new UnsupportedOperationException("Expect two arguments"); + } + StructField[] fields = inputType.fields(); + if (!(fields[0].dataType() instanceof LongType)) { + throw new UnsupportedOperationException("Expect first argument to be LongType"); + } + if (!(fields[1].dataType() instanceof LongType)) { + throw new UnsupportedOperationException("Expect second argument to be LongType"); + } + return impl; + } + + @Override + public String description() { + return "long_add"; + } + + private abstract static class JavaLongAddBase implements ScalarFunction { + private final boolean isResultNullable; + + JavaLongAddBase(boolean isResultNullable) { + 
this.isResultNullable = isResultNullable; + } + + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.LongType, DataTypes.LongType }; + } + + @Override + public DataType resultType() { + return DataTypes.LongType; + } + + @Override + public boolean isResultNullable() { + return isResultNullable; + } + } + + public static class JavaLongAddDefault extends JavaLongAddBase { + public JavaLongAddDefault(boolean isResultNullable) { + super(isResultNullable); + } + + @Override + public String name() { + return "long_add_default"; + } + + @Override + public Long produceResult(InternalRow input) { + return input.getLong(0) + input.getLong(1); + } + } + + public static class JavaLongAddMagic extends JavaLongAddBase { + public JavaLongAddMagic(boolean isResultNullable) { + super(isResultNullable); + } + + @Override + public String name() { + return "long_add_magic"; + } + + public long invoke(long left, long right) { + return left + right; + } + } + + public static class JavaLongAddStaticMagic extends JavaLongAddBase { + public JavaLongAddStaticMagic(boolean isResultNullable) { + super(isResultNullable); + } + + @Override + public String name() { + return "long_add_static_magic"; + } + + public static long invoke(long left, long right) { + return left + right; + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java index 8b2d883a3703..7cd010b9365b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java @@ -58,7 +58,7 @@ public String description() { " strlen(string) -> int"; } - public static class JavaStrLenDefault implements ScalarFunction { + private abstract static class JavaStrLenBase implements ScalarFunction { @Override public DataType[] 
inputTypes() { return new DataType[] { DataTypes.StringType }; @@ -73,7 +73,9 @@ public DataType resultType() { public String name() { return "strlen"; } + } + public static class JavaStrLenDefault extends JavaStrLenBase { @Override public Integer produceResult(InternalRow input) { String str = input.getString(0); @@ -81,42 +83,42 @@ public Integer produceResult(InternalRow input) { } } - public static class JavaStrLenMagic implements ScalarFunction { - @Override - public DataType[] inputTypes() { - return new DataType[] { DataTypes.StringType }; + public static class JavaStrLenMagic extends JavaStrLenBase { + public int invoke(UTF8String str) { + return str.toString().length(); } + } - @Override - public DataType resultType() { - return DataTypes.IntegerType; + public static class JavaStrLenStaticMagic extends JavaStrLenBase { + public static int invoke(UTF8String str) { + return str.toString().length(); } + } + public static class JavaStrLenBoth extends JavaStrLenBase { @Override - public String name() { - return "strlen"; + public Integer produceResult(InternalRow input) { + String str = input.getString(0); + return str.length(); } - public int invoke(UTF8String str) { - return str.toString().length(); + return str.toString().length() + 100; } } - public static class JavaStrLenNoImpl implements ScalarFunction { - @Override - public DataType[] inputTypes() { - return new DataType[] { DataTypes.StringType }; + // even though the static magic method is present, it has incorrect parameter type and so Spark + // should fallback to the non-static magic method + public static class JavaStrLenBadStaticMagic extends JavaStrLenBase { + public static int invoke(String str) { + return str.length(); } - @Override - public DataType resultType() { - return DataTypes.IntegerType; + public int invoke(UTF8String str) { + return str.toString().length() + 100; } + } - @Override - public String name() { - return "strlen"; - } + public static class JavaStrLenNoImpl extends 
JavaStrLenBase { } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index fe856ffecb84..b269da39daf3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -183,6 +183,28 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) } + test("scalar function: static magic method in Java") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenStaticMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: magic method should take higher precedence in Java") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenBoth)) + // to differentiate, the static method returns string length + 100 + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(103) :: Nil) + } + + test("scalar function: bad static magic method should fallback to non-static") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenBadStaticMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(103) :: Nil) + } + test("scalar function: no implementation found in Java") { catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) addFunction(Identifier.of(Array("ns"), "strlen"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala new file mode 100644 index 000000000000..9328a9a8e93e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.functions + +import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd +import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddStaticMagic} + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression} +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType, NumericType, StructType} + +/** + * Benchmark to measure DataSourceV2 UDF performance + * {{{ + * To run this benchmark: + * 1. without sbt: + * bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/V2FunctionBenchmark-results.txt". 
+ * }}} + * '''NOTE''': to update the result of this benchmark, please use Github benchmark action: + * https://spark.apache.org/developer-tools.html#github-workflow-benchmarks + */ +object V2FunctionBenchmark extends SqlBasedBenchmark { + val catalogName: String = "benchmark_catalog" + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val N = 500L * 1000 * 1000 + Seq(true, false).foreach { codegenEnabled => + Seq(true, false).foreach { resultNullable => + scalarFunctionBenchmark(N, codegenEnabled = codegenEnabled, + resultNullable = resultNullable) + } + } + } + + private def scalarFunctionBenchmark( + N: Long, + codegenEnabled: Boolean, + resultNullable: Boolean): Unit = { + withSQLConf(s"spark.sql.catalog.$catalogName" -> classOf[InMemoryCatalog].getName) { + createFunction("java_long_add_default", + new JavaLongAdd(new JavaLongAddDefault(resultNullable))) + createFunction("java_long_add_magic", new JavaLongAdd(new JavaLongAddMagic(resultNullable))) + createFunction("java_long_add_static_magic", + new JavaLongAdd(new JavaLongAddStaticMagic(resultNullable))) + createFunction("scala_long_add_default", + LongAddUnbound(new LongAddWithProduceResult(resultNullable))) + createFunction("scala_long_add_magic", LongAddUnbound(new LongAddWithMagic(resultNullable))) + + val codeGenFactoryMode = if (codegenEnabled) FALLBACK else NO_CODEGEN + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString, + SQLConf.CODEGEN_FACTORY_MODE.key -> codeGenFactoryMode.toString) { + val name = s"scalar function (long + long) -> long, result_nullable = $resultNullable " + + s"codegen = $codegenEnabled" + val benchmark = new Benchmark(name, N, output = output) + benchmark.addCase(s"native_long_add", numIters = 3) { _ => + spark.range(N).select(Column(NativeAdd($"id".expr, $"id".expr, resultNullable))).noop() + } + Seq("java_long_add_default", "java_long_add_magic", "java_long_add_static_magic", + "scala_long_add_default", 
"scala_long_add_magic").foreach { functionName => + benchmark.addCase(s"$functionName", numIters = 3) { _ => + spark.range(N).selectExpr(s"$catalogName.$functionName(id, id)").noop() + } + } + benchmark.run() + } + } + } + + private def createFunction(name: String, fn: UnboundFunction): Unit = { + val catalog = spark.sessionState.catalogManager.catalog(catalogName) + val ident = Identifier.of(Array.empty, name) + catalog.asInstanceOf[InMemoryCatalog].createFunction(ident, fn) + } + + case class NativeAdd( + left: Expression, + right: Expression, + override val nullable: Boolean) extends BinaryArithmetic { + override protected val failOnError: Boolean = true + override def inputType: AbstractDataType = NumericType + override def symbol: String = "+" + override def exactMathMethod: Option[String] = Some("addExact") + + private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError) + protected override def nullSafeEval(input1: Any, input2: Any): Any = + numeric.plus(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): NativeAdd = copy(left = newLeft, right = newRight) + } + + case class LongAddUnbound(impl: ScalarFunction[Long]) extends UnboundFunction { + override def bind(inputType: StructType): BoundFunction = impl + override def description(): String = name() + override def name(): String = "long_add_unbound" + } + + abstract class LongAddBase(resultNullable: Boolean) extends ScalarFunction[Long] { + override def inputTypes(): Array[DataType] = Array(LongType, LongType) + override def resultType(): DataType = LongType + override def isResultNullable: Boolean = resultNullable + } + + class LongAddWithProduceResult(resultNullable: Boolean) extends LongAddBase(resultNullable) { + override def produceResult(input: InternalRow): Long = { + input.getLong(0) + input.getLong(1) + } + override def name(): String = "long_add_default" + } + + class LongAddWithMagic(resultNullable: Boolean) extends 
LongAddBase(resultNullable) { + def invoke(left: Long, right: Long): Long = { + left + right + } + override def name(): String = "long_add_magic" + } +} +