[SPARK-35261][SQL] Support static magic method for stateless Java ScalarFunction

### What changes were proposed in this pull request?

This allows a `ScalarFunction` implemented in Java to optionally declare its magic method `invoke` as static, which can be used if the UDF is stateless. Compared with the non-static method, this can potentially give better performance due to the elimination of dynamic dispatch, etc.
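As a rough illustration of the dispatch decision described above, here is a minimal, Spark-free sketch in plain Java. The classes `InstanceLongAdd`, `StaticLongAdd` and `MagicMethodDispatch` are hypothetical stand-ins invented for this example; the real change builds Catalyst `StaticInvoke` / `Invoke` expressions in the analyzer rather than calling the method reflectively at runtime.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

// Hypothetical UDF-like class with a non-static magic method.
class InstanceLongAdd {
    public long invoke(long left, long right) {
        return left + right;
    }
}

// Hypothetical stateless UDF-like class with a static magic method.
class StaticLongAdd {
    public static long invoke(long left, long right) {
        return left + right;
    }
}

class MagicMethodDispatch {
    static final String MAGIC_METHOD_NAME = "invoke";

    // Returns true if the class declares a public static invoke(long, long).
    static boolean isStaticMagic(Class<?> cls) {
        try {
            Method m = cls.getMethod(MAGIC_METHOD_NAME, long.class, long.class);
            return Modifier.isStatic(m.getModifiers());
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    // Looks up invoke(long, long) and calls it with or without a receiver,
    // depending on whether the method is static.
    static long call(Object udf, long left, long right) {
        try {
            Method m = udf.getClass().getMethod(MAGIC_METHOD_NAME, long.class, long.class);
            Object receiver = Modifier.isStatic(m.getModifiers()) ? null : udf;
            return (Long) m.invoke(receiver, left, right);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("no usable magic method on " + udf.getClass(), e);
        }
    }

    public static void main(String[] args) {
        System.out.println(isStaticMagic(StaticLongAdd.class));
        System.out.println(isStaticMagic(InstanceLongAdd.class));
        System.out.println(call(new StaticLongAdd(), 2L, 3L));
        System.out.println(call(new InstanceLongAdd(), 2L, 3L));
    }
}
```

The only point of the sketch is the `Modifier.isStatic` check: a static magic method needs no receiver object, which is what lets the stateless case skip instance dispatch.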

Also added a benchmark to measure performance of: the default `produceResult`, non-static magic method and static magic method.

### Why are the changes needed?

For UDFs that are stateless (e.g., no need to maintain intermediate state between function calls), it's better to allow users to implement the UDF as a static method, which could potentially give better performance.

### Does this PR introduce _any_ user-facing change?

Yes. Spark users can now choose to define a static magic method for a `ScalarFunction` when it is written in Java and the UDF is stateless.

### How was this patch tested?

Added new UT.

Closes #32407 from sunchao/SPARK-35261.

Authored-by: Chao Sun <sunchao@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
sunchao authored and dongjoon-hyun committed May 8, 2021
1 parent b4ec9e2 commit f47e0f8
Showing 8 changed files with 461 additions and 39 deletions.
@@ -29,33 +29,62 @@
* <p>
* The JVM type of result values produced by this function must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
* The mapping between {@link DataType} and the corresponding JVM type is defined below.
* <p>
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
* {@link UnsupportedOperationException}. Users can choose to override this method, or implement
* a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
* instead of a {@link InternalRow}. The magic method will be loaded by Spark through Java
* reflection and will also provide better performance in general, due to optimizations such as
* codegen, removal of Java boxing, etc.
*
* {@link UnsupportedOperationException}. Users must choose to either override this method, or
* implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters
* instead of a {@link InternalRow}. The magic method approach is generally recommended because it
* provides better performance over the default {@link #produceResult}, due to optimizations such
* as whole-stage codegen, elimination of Java boxing, etc.
* <p>
* In addition, for stateless Java functions, users can optionally define the
* {@link #MAGIC_METHOD_NAME} as a static method, which further avoids certain runtime costs such
* as Java dynamic dispatch.
* <p>
* For example, a scalar UDF for adding two integers can be defined as follows with the magic
* method approach:
*
* <pre>
* public class IntegerAdd implements {@code ScalarFunction<Integer>} {
*   public DataType[] inputTypes() {
*     return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
*   }
*   public int invoke(int left, int right) {
*     return left + right;
*   }
* }
* </pre>
* In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
* {@link #produceResult} to evaluate the inputs. In general Spark looks up the magic method by
* first converting the actual input SQL data types to their corresponding Java types following
* the mapping defined below, and then checking if there is a matching method from all the
* declared methods in the UDF class, using method name (i.e., {@link #MAGIC_METHOD_NAME}) and
* the Java types. If no magic method is found, Spark will fall back to use {@link #produceResult}.
* In the above, since {@link #MAGIC_METHOD_NAME} is defined, and also that it has
* matching parameter types and return type, Spark will use it to evaluate inputs.
* <p>
* As another example, in the following:
* <pre>
* public class IntegerAdd implements {@code ScalarFunction<Integer>} {
*   public DataType[] inputTypes() {
*     return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
*   }
*   public static int invoke(int left, int right) {
*     return left + right;
*   }
*   public Integer produceResult(InternalRow input) {
*     return input.getInt(0) + input.getInt(1);
*   }
* }
* </pre>
*
* the class defines both the magic method and the {@link #produceResult}, and Spark will use
* {@link #MAGIC_METHOD_NAME} over the {@link #produceResult(InternalRow)} as it takes higher
* precedence. Also note that the magic method is annotated as a static method in this case.
* <p>
* Resolution on magic method is done during query analysis, where Spark looks up the magic
* method by first converting the actual input SQL data types to their corresponding Java types
* following the mapping defined below, and then checking if there is a matching method from all the
* declared methods in the UDF class, using method name and the Java types.
* <p>
* The following are the mapping from {@link DataType SQL data type} to Java type through
* the magic method approach:
* The following are the mapping from {@link DataType SQL data type} to Java type which is used
* by Spark to infer parameter types for the magic methods as well as return value type for
* {@link #produceResult}:
* <ul>
* <li>{@link org.apache.spark.sql.types.BooleanType}: {@code boolean}</li>
* <li>{@link org.apache.spark.sql.types.ByteType}: {@code byte}</li>
@@ -80,7 +109,8 @@
* {@link org.apache.spark.sql.catalyst.util.MapData}</li>
* </ul>
*
* @param <R> the JVM type of result values
* @param <R> the JVM type of result values, MUST be consistent with the {@link DataType}
* returned via {@link #resultType()}, according to the mapping above.
*/
public interface ScalarFunction<R> extends BoundFunction {
String MAGIC_METHOD_NAME = "invoke";
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.Method
import java.lang.reflect.{Method, Modifier}
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@@ -2181,6 +2181,9 @@ class Analyzer(override val catalogManager: CatalogManager)
  // match the input type through `BoundFunction.inputTypes`.
  val argClasses = inputType.fields.map(_.dataType)
  findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
    case Some(m) if Modifier.isStatic(m.getModifiers) =>
      StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
        MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable)
    case Some(_) =>
      val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
      Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
44 changes: 44 additions & 0 deletions sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt
@@ -0,0 +1,44 @@
OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
scalar function (long + long) -> long, result_nullable = true codegen = true:     Best Time(ms)  Avg Time(ms)  Stdev(ms)  Rate(M/s)  Per Row(ns)  Relative
----------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                           17789         18405        580       28.1         35.6      1.0X
java_long_add_default                                                                     85058         87073        NaN        5.9        170.1      0.2X
java_long_add_magic                                                                       20262         20641        352       24.7         40.5      0.9X
java_long_add_static_magic                                                                19458         19524        105       25.7         38.9      0.9X
scala_long_add_default                                                                    85892         86496        560        5.8        171.8      0.2X
scala_long_add_magic                                                                      20164         20330        212       24.8         40.3      0.9X

OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
scalar function (long + long) -> long, result_nullable = false codegen = true:    Best Time(ms)  Avg Time(ms)  Stdev(ms)  Rate(M/s)  Per Row(ns)  Relative
----------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                           18290         18467        157       27.3         36.6      1.0X
java_long_add_default                                                                     82415         82687        270        6.1        164.8      0.2X
java_long_add_magic                                                                       19941         20032         85       25.1         39.9      0.9X
java_long_add_static_magic                                                                17861         17940         92       28.0         35.7      1.0X
scala_long_add_default                                                                    83800         85639        NaN        6.0        167.6      0.2X
scala_long_add_magic                                                                      20103         20123         18       24.9         40.2      0.9X

OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
scalar function (long + long) -> long, result_nullable = true codegen = false:    Best Time(ms)  Avg Time(ms)  Stdev(ms)  Rate(M/s)  Per Row(ns)  Relative
----------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                           46039         46199        162       10.9         92.1      1.0X
java_long_add_default                                                                    113199        113773        720        4.4        226.4      0.4X
java_long_add_magic                                                                      158252        159419       1075        3.2        316.5      0.3X
java_long_add_static_magic                                                               157162        157676        516        3.2        314.3      0.3X
scala_long_add_default                                                                   112363        113264       1503        4.4        224.7      0.4X
scala_long_add_magic                                                                     158122        159010        835        3.2        316.2      0.3X

OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
scalar function (long + long) -> long, result_nullable = false codegen = false:   Best Time(ms)  Avg Time(ms)  Stdev(ms)  Rate(M/s)  Per Row(ns)  Relative
----------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                           42685         42743         54       11.7         85.4      1.0X
java_long_add_default                                                                     92041         92236        202        5.4        184.1      0.5X
java_long_add_magic                                                                      148299        148722        397        3.4        296.6      0.3X
java_long_add_static_magic                                                               140599        141064        442        3.6        281.2      0.3X
scala_long_add_default                                                                    91896         92980        959        5.4        183.8      0.5X
scala_long_add_magic                                                                     148031        148802        759        3.4        296.1      0.3X

44 changes: 44 additions & 0 deletions sql/core/benchmarks/V2FunctionBenchmark-results.txt
@@ -0,0 +1,44 @@
OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
scalar function (long + long) -> long, result_nullable = true codegen = true:     Best Time(ms)  Avg Time(ms)  Stdev(ms)  Rate(M/s)  Per Row(ns)  Relative
----------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                           10559         11585        903       47.4         21.1      1.0X
java_long_add_default                                                                     78979         80089        987        6.3        158.0      0.1X
java_long_add_magic                                                                       14061         14326        305       35.6         28.1      0.8X
java_long_add_static_magic                                                                11971         12150        242       41.8         23.9      0.9X
scala_long_add_default                                                                    77254         78565       1254        6.5        154.5      0.1X
scala_long_add_magic                                                                      13174         13232         51       38.0         26.3      0.8X

OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
scalar function (long + long) -> long, result_nullable = false codegen = true:    Best Time(ms)  Avg Time(ms)  Stdev(ms)  Rate(M/s)  Per Row(ns)  Relative
----------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                           10489         10665        162       47.7         21.0      1.0X
java_long_add_default                                                                     66636         68422        NaN        7.5        133.3      0.2X
java_long_add_magic                                                                       13504         14213        883       37.0         27.0      0.8X
java_long_add_static_magic                                                                11726         11984        240       42.6         23.5      0.9X
scala_long_add_default                                                                    75906         76130        196        6.6        151.8      0.1X
scala_long_add_magic                                                                      14480         14770        261       34.5         29.0      0.7X

OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
scalar function (long + long) -> long, result_nullable = true codegen = false:    Best Time(ms)  Avg Time(ms)  Stdev(ms)  Rate(M/s)  Per Row(ns)  Relative
----------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                           39178         39548        323       12.8         78.4      1.0X
java_long_add_default                                                                     84756         85509       1092        5.9        169.5      0.5X
java_long_add_magic                                                                      199140        200801       1823        2.5        398.3      0.2X
java_long_add_static_magic                                                               203500        207050        NaN        2.5        407.0      0.2X
scala_long_add_default                                                                   101180        101421        387        4.9        202.4      0.4X
scala_long_add_magic                                                                     193277        197006       1138        2.6        386.6      0.2X

OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
scalar function (long + long) -> long, result_nullable = false codegen = false:   Best Time(ms)  Avg Time(ms)  Stdev(ms)  Rate(M/s)  Per Row(ns)  Relative
----------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                           37064         37333        235       13.5         74.1      1.0X
java_long_add_default                                                                    104439        107802        NaN        4.8        208.9      0.4X
java_long_add_magic                                                                      212496        214321        NaN        2.4        425.0      0.2X
java_long_add_static_magic                                                               239551        240619       1652        2.1        479.1      0.2X
scala_long_add_default                                                                   122413        123171        788        4.1        244.8      0.3X
scala_long_add_magic                                                                     215912        222715        NaN        2.3        431.8      0.2X
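
For readers comparing rows, the derived columns appear to follow from the best time alone. The sketch below is an assumption-laden reading of the tables, not the benchmark's own code: the row count of 500 million is not stated anywhere on this page and is inferred from the ratio of `Best Time(ms)` to `Per Row(ns)`, and the class `BenchmarkColumns` and its method names are made up for illustration.

```java
// Hypothetical helper that recomputes the derived benchmark columns
// from a best time in milliseconds.
class BenchmarkColumns {
    // ASSUMPTION: inferred from the tables (e.g. 10559 ms / 21.1 ns per row),
    // not stated in the commit page.
    static final long NUM_ROWS = 500_000_000L;

    // Rate(M/s): millions of rows processed per second.
    static double rateMillionsPerSecond(double bestMs) {
        return NUM_ROWS / (bestMs / 1000.0) / 1_000_000.0;
    }

    // Per Row(ns): nanoseconds spent per row.
    static double perRowNanos(double bestMs) {
        return bestMs * 1_000_000.0 / NUM_ROWS;
    }

    // Relative: speed relative to the first (baseline) row of the same table.
    static double relative(double baselineBestMs, double bestMs) {
        return baselineBestMs / bestMs;
    }

    public static void main(String[] args) {
        double nativeBest = 10559;      // native_long_add, JDK 8, nullable, codegen on
        double staticMagicBest = 11971; // java_long_add_static_magic, same table
        System.out.printf("rate=%.1f perRow=%.1f relative=%.1fX%n",
            rateMillionsPerSecond(staticMagicBest),
            perRowNanos(staticMagicBest),
            relative(nativeBest, staticMagicBest));
    }
}
```

Under that assumed row count, the recomputed values round to the figures printed for the first JDK 8 table, which suggests `Relative` is simply the baseline's best time divided by the row's best time.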
