[SPARK-35261][SQL] Support static magic method for stateless Java ScalarFunction #32407

Closed
@@ -29,33 +29,62 @@
* <p>
* The JVM type of result values produced by this function must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
* The mapping between {@link DataType} and the corresponding JVM type is defined below.
* <p>
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
* {@link UnsupportedOperationException}. Users can choose to override this method, or implement
* a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
* instead of a {@link InternalRow}. The magic method will be loaded by Spark through Java
* reflection and will also provide better performance in general, due to optimizations such as
* codegen, removal of Java boxing, etc.
*
* {@link UnsupportedOperationException}. Users must choose to either override this method, or
* implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters
* instead of a {@link InternalRow}. The magic method approach is generally recommended because it
* provides better performance over the default {@link #produceResult}, due to optimizations such
* as whole-stage codegen, elimination of Java boxing, etc.
* <p>
* In addition, for stateless Java functions, users can optionally define the
* {@link #MAGIC_METHOD_NAME} as a static method, which further avoids certain runtime costs such
* as Java dynamic dispatch.
* <p>
* For example, a scalar UDF for adding two integers can be defined as follows with the magic
* method approach:
*
* <pre>
*   public class IntegerAdd implements{@code ScalarFunction<Integer>} {
*     public DataType[] inputTypes() {
*       return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
*     }
*     public int invoke(int left, int right) {
*       return left + right;
*     }
*   }
* </pre>
* In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
* {@link #produceResult} to evaluate the inputs. In general Spark looks up the magic method by
* first converting the actual input SQL data types to their corresponding Java types following
* the mapping defined below, and then checking if there is a matching method from all the
* declared methods in the UDF class, using method name (i.e., {@link #MAGIC_METHOD_NAME}) and
* the Java types. If no magic method is found, Spark will fall back to {@link #produceResult}.
* In the above, since {@link #MAGIC_METHOD_NAME} is defined with matching parameter types and
* return type, Spark will use it to evaluate inputs.
* <p>
* As another example, in the following:
* <pre>
*   public class IntegerAdd implements{@code ScalarFunction<Integer>} {
*     public DataType[] inputTypes() {
*       return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
*     }
*     public static int invoke(int left, int right) {
*       return left + right;
*     }
*     public Integer produceResult(InternalRow input) {
*       return input.getInt(0) + input.getInt(1);
*     }
*   }
* </pre>
*
* the class defines both the magic method and the {@link #produceResult}, and Spark will use
* {@link #MAGIC_METHOD_NAME} over the {@link #produceResult(InternalRow)} as it takes higher
* precedence. Also note that the magic method is declared as a static method in this case.
* <p>
* Resolution of the magic method is done during query analysis, where Spark looks up the magic
* method by first converting the actual input SQL data types to their corresponding Java types
* following the mapping defined below, and then checking if there is a matching method from all the
* declared methods in the UDF class, using method name and the Java types.
* <p>
* The following is the mapping from {@link DataType SQL data type} to Java type through
* the magic method approach:
* The following is the mapping from {@link DataType SQL data type} to Java type which is used
* by Spark to infer parameter types for the magic methods as well as return value type for
* {@link #produceResult}:
* <ul>
* <li>{@link org.apache.spark.sql.types.BooleanType}: {@code boolean}</li>
* <li>{@link org.apache.spark.sql.types.ByteType}: {@code byte}</li>
@@ -80,7 +109,8 @@
* {@link org.apache.spark.sql.catalyst.util.MapData}</li>
* </ul>
*
* @param <R> the JVM type of result values
* @param <R> the JVM type of result values, MUST be consistent with the {@link DataType}
* returned via {@link #resultType()}, according to the mapping above.
*/
public interface ScalarFunction<R> extends BoundFunction {
String MAGIC_METHOD_NAME = "invoke";
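
The lookup that the Javadoc above describes (match on method name and Java parameter types, prefer a static `invoke`, otherwise fall back to `produceResult`) can be sketched with plain Java reflection. This is only an illustration under stated assumptions: `InstanceAdd`, `StaticAdd`, `NoMagic`, `findMagicMethod` and `describe` are hypothetical names, the classes do not implement Spark's `ScalarFunction`, and the SQL-to-Java type mapping is skipped by passing Java classes directly. It is not the Analyzer's actual code.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Optional;

public class MagicMethodLookupSketch {
    static final String MAGIC_METHOD_NAME = "invoke";

    // Hypothetical stand-ins for UDF classes (they do not implement ScalarFunction).
    static class InstanceAdd {
        public long invoke(long left, long right) { return left + right; }
    }

    static class StaticAdd {
        public static long invoke(long left, long right) { return left + right; }
    }

    static class NoMagic {
        public Long produceResult(long[] row) { return row[0] + row[1]; }
    }

    // Search the declared methods for one named "invoke" with exactly these parameter types.
    static Optional<Method> findMagicMethod(Class<?> udfClass, Class<?>... argTypes) {
        return Arrays.stream(udfClass.getDeclaredMethods())
            .filter(m -> m.getName().equals(MAGIC_METHOD_NAME))
            .filter(m -> Arrays.equals(m.getParameterTypes(), argTypes))
            .findFirst();
    }

    // Report which evaluation path would be chosen for the given class and argument types.
    static String describe(Class<?> udfClass, Class<?>... argTypes) {
        return findMagicMethod(udfClass, argTypes)
            .map(m -> Modifier.isStatic(m.getModifiers()) ? "static" : "instance")
            .orElse("fallback");
    }

    public static void main(String[] args) {
        System.out.println(describe(StaticAdd.class, long.class, long.class));   // static
        System.out.println(describe(InstanceAdd.class, long.class, long.class)); // instance
        System.out.println(describe(NoMagic.class, long.class, long.class));     // fallback
        // Parameter types must match exactly: (int, int) does not match invoke(long, long).
        System.out.println(describe(StaticAdd.class, int.class, int.class));     // fallback
    }
}
```

The last call shows why the parameter types of the magic method have to follow the type mapping: a mismatch means no magic method is found for those inputs.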
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.Method
import java.lang.reflect.{Method, Modifier}
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@@ -2181,6 +2181,9 @@ class Analyzer(override val catalogManager: CatalogManager)
// match the input type through `BoundFunction.inputTypes`.
val argClasses = inputType.fields.map(_.dataType)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
  case Some(m) if Modifier.isStatic(m.getModifiers) =>
    StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
Member Author

Ideally we'd want to check whether this method is actually static; otherwise there could be a runtime error. However, this only works for methods defined in Java; for Scala there seems to be no easy way.

Contributor

Scala doesn't have the concept of truly static methods, right? The equivalent (object methods) are actually just instance methods on a singleton.

Member Author

@sunchao, Apr 30, 2021

Yes, that's correct. StaticInvoke calls the static method on the non-anonymous class, which just forwards to the non-static method defined in the anonymous/singleton Java class (i.e., the class with $ at the end of its name).

For instance, for the LongAddWithStaticMagic class, this is the method defined in LongAddWithStaticMagic.class:

  public static long staticInvoke(long, long);
    Code:
       0: getstatic     #16                 // Field org/apache/spark/sql/connector/functions/LongAddWithStaticMagic$.MODULE$:Lorg/apache/spark/sql/connector/functions/LongAddWithStaticMagic$;
       3: lload_0
       4: lload_2
       5: invokevirtual #51                 // Method org/apache/spark/sql/connector/functions/LongAddWithStaticMagic$.staticInvoke:(JJ)J
       8: lreturn

and the same method defined in the singleton class LongAddWithStaticMagic$:

  public long staticInvoke(long, long);
    Code:
       0: lload_1
       1: lload_3
       2: ladd
       3: lreturn

So I was expecting worse performance from Scala, since it calls invokevirtual underneath while Java uses invokestatic, but the results don't show that. It could be that the performance is dominated by other factors.
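
For readers less familiar with how a Scala object compiles, the forwarding shape in the two bytecode listings above can be approximated in plain Java. This is a rough sketch, not the compiler's real output: `LongAddModule` stands in for the singleton class whose name ends in `$`, `MODULE` for its `MODULE$` field, and `LongAdd` for the class that carries the static forwarder; all three names are hypothetical.

```java
public class StaticForwarderSketch {
    // Stand-in for the singleton class (the one whose name ends in `$`).
    static final class LongAddModule {
        // Stand-in for the MODULE$ field that holds the single instance.
        static final LongAddModule MODULE = new LongAddModule();

        private LongAddModule() {}

        // The real work lives in an instance method.
        long staticInvoke(long left, long right) {
            return left + right;
        }
    }

    // Stand-in for the class that exposes the static forwarder.
    static final class LongAdd {
        // Loads the singleton and delegates to its instance method,
        // mirroring the getstatic + invokevirtual sequence in the listing.
        static long staticInvoke(long left, long right) {
            return LongAddModule.MODULE.staticInvoke(left, right);
        }
    }

    public static void main(String[] args) {
        System.out.println(LongAdd.staticInvoke(2L, 3L)); // 5
    }
}
```

A caller only ever sees a static method, which is why a check based on `Modifier.isStatic` cannot tell this forwarder apart from a genuinely static Java method.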

Contributor

Very interesting, thanks for the explanation!

      MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable)
  case Some(_) =>
    val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
    Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
44 changes: 44 additions & 0 deletions sql/core/benchmarks/V2FunctionBenchmark-jdk11-results.txt
@@ -0,0 +1,44 @@
OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
scalar function (long + long) -> long, result_nullable = true codegen = true:   Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                         17789          18405         580        28.1          35.6       1.0X
java_long_add_default                                                                   85058          87073         NaN         5.9         170.1       0.2X
java_long_add_magic                                                                     20262          20641         352        24.7          40.5       0.9X
java_long_add_static_magic                                                              19458          19524         105        25.7          38.9       0.9X
scala_long_add_default                                                                  85892          86496         560         5.8         171.8       0.2X
scala_long_add_magic                                                                    20164          20330         212        24.8          40.3       0.9X

OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
scalar function (long + long) -> long, result_nullable = false codegen = true:  Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                         18290          18467         157        27.3          36.6       1.0X
java_long_add_default                                                                   82415          82687         270         6.1         164.8       0.2X
java_long_add_magic                                                                     19941          20032          85        25.1          39.9       0.9X
java_long_add_static_magic                                                              17861          17940          92        28.0          35.7       1.0X
scala_long_add_default                                                                  83800          85639         NaN         6.0         167.6       0.2X
scala_long_add_magic                                                                    20103          20123          18        24.9          40.2       0.9X

OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
scalar function (long + long) -> long, result_nullable = true codegen = false:  Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                         46039          46199         162        10.9          92.1       1.0X
java_long_add_default                                                                  113199         113773         720         4.4         226.4       0.4X
java_long_add_magic                                                                    158252         159419        1075         3.2         316.5       0.3X
java_long_add_static_magic                                                             157162         157676         516         3.2         314.3       0.3X
scala_long_add_default                                                                 112363         113264        1503         4.4         224.7       0.4X
scala_long_add_magic                                                                   158122         159010         835         3.2         316.2       0.3X

OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
scalar function (long + long) -> long, result_nullable = false codegen = false: Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                         42685          42743          54        11.7          85.4       1.0X
java_long_add_default                                                                   92041          92236         202         5.4         184.1       0.5X
java_long_add_magic                                                                    148299         148722         397         3.4         296.6       0.3X
java_long_add_static_magic                                                             140599         141064         442         3.6         281.2       0.3X
scala_long_add_default                                                                  91896          92980         959         5.4         183.8       0.5X
scala_long_add_magic                                                                   148031         148802         759         3.4         296.1       0.3X
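
As a reading aid for the tables above, the derived columns appear to follow from the best time by simple arithmetic. The sketch below assumes 500,000,000 rows per case; that count is inferred from the reported figures (for example 17789 ms at 28.1 M rows/s) and is not stated in the results file. The class and method names are hypothetical.

```java
import java.util.Locale;

public class BenchmarkColumnsSketch {
    // Assumption: row count inferred from the tables, not taken from the benchmark source.
    static final long NUM_ROWS = 500_000_000L;

    // Rate(M/s): millions of rows processed per second at the best time.
    static double rateMillionsPerSec(long bestTimeMs) {
        return NUM_ROWS / (bestTimeMs / 1000.0) / 1_000_000.0;
    }

    // Per Row(ns): best time in nanoseconds divided by the number of rows.
    static double perRowNanos(long bestTimeMs) {
        return bestTimeMs * 1_000_000.0 / NUM_ROWS;
    }

    // Relative: baseline best time over this case's best time (1.0X is the baseline).
    static double relative(long baselineBestMs, long caseBestMs) {
        return (double) baselineBestMs / caseBestMs;
    }

    public static void main(String[] args) {
        long nativeBest = 17789;      // native_long_add, JDK 11, nullable = true, codegen = true
        long javaDefaultBest = 85058; // java_long_add_default, same table
        System.out.println(String.format(Locale.ROOT, "%.1f", rateMillionsPerSec(nativeBest)));
        System.out.println(String.format(Locale.ROOT, "%.1f", perRowNanos(nativeBest)));
        System.out.println(String.format(Locale.ROOT, "%.1fX", relative(nativeBest, javaDefaultBest)));
    }
}
```

Under that assumption the three printed values line up with the 28.1, 35.6 and 0.2X entries in the first JDK 11 table.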

44 changes: 44 additions & 0 deletions sql/core/benchmarks/V2FunctionBenchmark-results.txt
@@ -0,0 +1,44 @@
OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
scalar function (long + long) -> long, result_nullable = true codegen = true:   Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                         10559          11585         903        47.4          21.1       1.0X
java_long_add_default                                                                   78979          80089         987         6.3         158.0       0.1X
java_long_add_magic                                                                     14061          14326         305        35.6          28.1       0.8X
java_long_add_static_magic                                                              11971          12150         242        41.8          23.9       0.9X
scala_long_add_default                                                                  77254          78565        1254         6.5         154.5       0.1X
scala_long_add_magic                                                                    13174          13232          51        38.0          26.3       0.8X

OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
scalar function (long + long) -> long, result_nullable = false codegen = true:  Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                         10489          10665         162        47.7          21.0       1.0X
java_long_add_default                                                                   66636          68422         NaN         7.5         133.3       0.2X
java_long_add_magic                                                                     13504          14213         883        37.0          27.0       0.8X
java_long_add_static_magic                                                              11726          11984         240        42.6          23.5       0.9X
scala_long_add_default                                                                  75906          76130         196         6.6         151.8       0.1X
scala_long_add_magic                                                                    14480          14770         261        34.5          29.0       0.7X

OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
scalar function (long + long) -> long, result_nullable = true codegen = false:  Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                         39178          39548         323        12.8          78.4       1.0X
java_long_add_default                                                                   84756          85509        1092         5.9         169.5       0.5X
java_long_add_magic                                                                    199140         200801        1823         2.5         398.3       0.2X
java_long_add_static_magic                                                             203500         207050         NaN         2.5         407.0       0.2X
scala_long_add_default                                                                 101180         101421         387         4.9         202.4       0.4X
scala_long_add_magic                                                                   193277         197006        1138         2.6         386.6       0.2X

OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
scalar function (long + long) -> long, result_nullable = false codegen = false: Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
native_long_add                                                                         37064          37333         235        13.5          74.1       1.0X
java_long_add_default                                                                  104439         107802         NaN         4.8         208.9       0.4X
java_long_add_magic                                                                    212496         214321         NaN         2.4         425.0       0.2X
java_long_add_static_magic                                                             239551         240619        1652         2.1         479.1       0.2X
scala_long_add_default                                                                 122413         123171         788         4.1         244.8       0.3X
scala_long_add_magic                                                                   215912         222715         NaN         2.3         431.8       0.2X
