[SPARK-34981][SQL] Implement V2 function resolution and evaluation #32082
Changes from all commits
ScalarFunction.java (org.apache.spark.sql.connector.catalog.functions)

```diff
@@ -23,17 +23,67 @@
 /**
  * Interface for a function that produces a result value for each input row.
  * <p>
- * For each input row, Spark will call a produceResult method that corresponds to the
- * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
- * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
- * {@link #produceResult(InternalRow)}.
+ * To evaluate each input row, Spark will first try to look up and use a "magic method" (described
+ * below) through Java reflection. If the method is not found, Spark will call
+ * {@link #produceResult(InternalRow)} as a fallback approach.
  * <p>
  * The JVM type of result values produced by this function must be the type used by Spark's
  * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
  * <p>
+ * <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
+ * {@link UnsupportedOperationException}. Users can choose to override this method, or implement
+ * a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
+ * instead of an {@link InternalRow}. The magic method will be loaded by Spark through Java
+ * reflection and will also provide better performance in general, due to optimizations such as
+ * codegen and the removal of Java boxing.
+ *
+ * For example, a scalar UDF for adding two integers can be defined as follows with the magic
+ * method approach:
+ *
+ * <pre>
+ *   public class IntegerAdd implements {@code ScalarFunction<Integer>} {
+ *     public int invoke(int left, int right) {
```
A review thread anchored on the `invoke` line above:

> **@sunchao** (Member, Author): @cloud-fan I think we can also consider adding another "static invoke" API for those stateless UDFs. From the benchmark you did some time back, it seems this can give a decent performance improvement. WDYT?
>
> **@cloud-fan** (Contributor): @sunchao can you spend some time on the API design? I'd love to see this feature!
>
> **@sunchao** (Member, Author): Sure, will do. It should be similar to the current …
>
> **@cloud-fan** (Contributor): yea
```diff
+ *       return left + right;
+ *     }
+ *   }
+ * </pre>
+ * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
+ * {@link #produceResult} to evaluate the inputs. In general, Spark looks up the magic method by
+ * first converting the actual input SQL data types to their corresponding Java types following
+ * the mapping defined below, and then checking if there is a matching method among all the
+ * declared methods in the UDF class, using the method name (i.e., {@link #MAGIC_METHOD_NAME})
+ * and the Java types. If no magic method is found, Spark falls back to {@link #produceResult}.
+ * <p>
+ * The following is the mapping from {@link DataType SQL data type} to Java type expected by
+ * the magic method approach:
+ * <ul>
+ *   <li>{@link org.apache.spark.sql.types.BooleanType}: {@code boolean}</li>
+ *   <li>{@link org.apache.spark.sql.types.ByteType}: {@code byte}</li>
+ *   <li>{@link org.apache.spark.sql.types.ShortType}: {@code short}</li>
+ *   <li>{@link org.apache.spark.sql.types.IntegerType}: {@code int}</li>
+ *   <li>{@link org.apache.spark.sql.types.LongType}: {@code long}</li>
+ *   <li>{@link org.apache.spark.sql.types.FloatType}: {@code float}</li>
+ *   <li>{@link org.apache.spark.sql.types.DoubleType}: {@code double}</li>
+ *   <li>{@link org.apache.spark.sql.types.StringType}:
+ *       {@link org.apache.spark.unsafe.types.UTF8String}</li>
+ *   <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
+ *   <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
+ *   <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
+ *   <li>{@link org.apache.spark.sql.types.DayTimeIntervalType}: {@code long}</li>
+ *   <li>{@link org.apache.spark.sql.types.YearMonthIntervalType}: {@code int}</li>
+ *   <li>{@link org.apache.spark.sql.types.DecimalType}:
+ *       {@link org.apache.spark.sql.types.Decimal}</li>
+ *   <li>{@link org.apache.spark.sql.types.StructType}: {@link InternalRow}</li>
+ *   <li>{@link org.apache.spark.sql.types.ArrayType}:
+ *       {@link org.apache.spark.sql.catalyst.util.ArrayData}</li>
+ *   <li>{@link org.apache.spark.sql.types.MapType}:
+ *       {@link org.apache.spark.sql.catalyst.util.MapData}</li>
+ * </ul>
+ *
  * @param <R> the JVM type of result values
  */
 public interface ScalarFunction<R> extends BoundFunction {
+  String MAGIC_METHOD_NAME = "invoke";
 
   /**
    * Applies the function to an input row to produce a value.
```
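To make the contract above concrete, here is a minimal, self-contained sketch of a UDF that implements both evaluation paths: the magic method for the reflection/codegen path, and `produceResult` as the row-based fallback. The class name and behavior are illustrative only, not part of this PR; `StringType` maps to `UTF8String` per the table above.

```java
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;

// Hypothetical example: string length in bytes. StringType maps to
// UTF8String per the mapping table, so the magic method takes UTF8String.
public class StrLen implements ScalarFunction<Integer> {

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.StringType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.IntegerType;
  }

  @Override
  public String name() {
    return "strlen";
  }

  // Magic method: located via reflection using MAGIC_METHOD_NAME ("invoke")
  // and the Java classes mapped from the input SQL types.
  public int invoke(UTF8String str) {
    return str.numBytes();
  }

  // Fallback: used when no matching magic method is declared.
  @Override
  public Integer produceResult(InternalRow input) {
    return input.getUTF8String(0).numBytes();
  }
}
```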
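Regarding the "static invoke" idea discussed in the review thread above: nothing in this diff defines such an API, but a plausible shape, purely as a speculative sketch, would be to declare the magic method as `static` so that evaluation needs no instance receiver:

```java
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

// Speculative sketch only: a stateless UDF exposing the magic method as a
// static method, so codegen could emit a direct static call. Whether Spark
// resolves static magic methods is NOT established by this PR.
public class StaticIntegerAdd implements ScalarFunction<Integer> {

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.IntegerType;
  }

  @Override
  public String name() {
    return "static_integer_add";
  }

  public static int invoke(int left, int right) {
    return left + right;
  }
}
```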
Analyzer.scala (org.apache.spark.sql.catalyst.analysis)

```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import java.lang.reflect.Method
 import java.util
 import java.util.Locale
 import java.util.concurrent.atomic.AtomicBoolean
@@ -29,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
-import org.apache.spark.sql.catalyst.expressions.{FrameLessOffsetWindowFunction, _}
+import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects._
@@ -44,6 +45,8 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -281,7 +284,7 @@
       ResolveAggregateFunctions ::
       TimeWindowing ::
       ResolveInlineTables ::
-      ResolveHigherOrderFunctions(v1SessionCatalog) ::
+      ResolveHigherOrderFunctions(catalogManager) ::
       ResolveLambdaVariables ::
       ResolveTimeZone ::
       ResolveRandomSeed ::
@@ -895,9 +898,10 @@
     }
   }
 
-  // If we are resolving relations insides views, we need to expand single-part relation names with
-  // the current catalog and namespace of when the view was created.
-  private def expandRelationName(nameParts: Seq[String]): Seq[String] = {
+  // If we are resolving database objects (relations, functions, etc.) inside views, we may need to
+  // expand single or multi-part identifiers with the current catalog and namespace of when the
+  // view was created.
+  private def expandIdentifier(nameParts: Seq[String]): Seq[String] = {
     if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts
 
     if (nameParts.length == 1) {
@@ -1040,7 +1044,7 @@
       identifier: Seq[String],
       options: CaseInsensitiveStringMap,
       isStreaming: Boolean): Option[LogicalPlan] =
-    expandRelationName(identifier) match {
+    expandIdentifier(identifier) match {
       case NonSessionCatalogAndIdentifier(catalog, ident) =>
         CatalogV2Util.loadTable(catalog, ident) match {
           case Some(table) =>
@@ -1153,7 +1157,7 @@
   }
 
   private def lookupTableOrView(identifier: Seq[String]): Option[LogicalPlan] = {
-    expandRelationName(identifier) match {
+    expandIdentifier(identifier) match {
       case SessionCatalogAndIdentifier(catalog, ident) =>
         CatalogV2Util.loadTable(catalog, ident).map {
           case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW =>
@@ -1173,7 +1177,7 @@
       identifier: Seq[String],
       options: CaseInsensitiveStringMap,
       isStreaming: Boolean): Option[LogicalPlan] = {
-    expandRelationName(identifier) match {
+    expandIdentifier(identifier) match {
       case SessionCatalogAndIdentifier(catalog, ident) =>
         lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map {
           case v1Table: V1Table =>
@@ -1569,8 +1573,7 @@
       // results and confuse users if there is any null values. For count(t1.*, t2.*), it is
       // still allowed, since it's well-defined in spark.
       if (!conf.allowStarWithSingleTableIdentifierInCount &&
-          f1.name.database.isEmpty &&
-          f1.name.funcName == "count" &&
+          f1.nameParts == Seq("count") &&
           f1.arguments.length == 1) {
         f1.arguments.foreach {
           case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) =>
@@ -1958,17 +1961,19 @@
     override def apply(plan: LogicalPlan): LogicalPlan = {
       val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
       plan.resolveExpressions {
-        case f: UnresolvedFunction
-            if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
-        case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
-        case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
-          externalFunctionNameSet.add(normalizeFuncName(f.name))
-          f
-        case f: UnresolvedFunction =>
-          withPosition(f) {
-            throw new NoSuchFunctionException(
-              f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
-              f.name.funcName)
+        case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) =>
+          if (externalFunctionNameSet.contains(normalizeFuncName(ident)) ||
+              v1SessionCatalog.isRegisteredFunction(ident)) {
+            f
+          } else if (v1SessionCatalog.isPersistentFunction(ident)) {
+            externalFunctionNameSet.add(normalizeFuncName(ident))
+            f
+          } else {
+            withPosition(f) {
+              throw new NoSuchFunctionException(
+                ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
+                ident.funcName)
+            }
           }
       }
     }
@@ -2016,9 +2021,10 @@
             name, other.getClass.getCanonicalName)
         }
       }
-      case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter, ignoreNulls) =>
-        withPosition(u) {
-          v1SessionCatalog.lookupFunction(funcId, arguments) match {
+
+      case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, isDistinct, filter,
+          ignoreNulls) => withPosition(u) {
+        v1SessionCatalog.lookupFunction(ident, arguments) match {
           // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
           // the context of a Window clause. They do not need to be wrapped in an
           // AggregateExpression.
@@ -2095,9 +2101,123 @@
         case other =>
           other
       }
     }
 
+      case u @ UnresolvedFunction(nameParts, arguments, isDistinct, filter, ignoreNulls) =>
+        withPosition(u) {
+          expandIdentifier(nameParts) match {
+            case NonSessionCatalogAndIdentifier(catalog, ident) =>
+              if (!catalog.isFunctionCatalog) {
+                throw new AnalysisException(s"Trying to look up function '$ident' in " +
+                  s"catalog '${catalog.name()}', but it is not a FunctionCatalog.")
+              }
+
+              val unbound = catalog.asFunctionCatalog.loadFunction(ident)
+              val inputType = StructType(arguments.zipWithIndex.map {
+                case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
+              })
+              val bound = try {
+                unbound.bind(inputType)
+              } catch {
+                case unsupported: UnsupportedOperationException =>
+                  throw new AnalysisException(s"Function '${unbound.name}' cannot process " +
+                    s"input: (${arguments.map(_.dataType.simpleString).mkString(", ")}): " +
+                    unsupported.getMessage, cause = Some(unsupported))
+              }
+
+              bound match {
+                case scalarFunc: ScalarFunction[_] =>
+                  processV2ScalarFunction(scalarFunc, inputType, arguments, isDistinct,
+                    filter, ignoreNulls)
+                case aggFunc: V2AggregateFunction[_, _] =>
```
A review comment on the `V2AggregateFunction` branch:

> **@cloud-fan** (Contributor): ditto, put into a new method.
```diff
+                  processV2AggregateFunction(aggFunc, arguments, isDistinct, filter,
+                    ignoreNulls)
+                case _ =>
+                  failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction" +
+                    s" or AggregateFunction")
+              }
+
+            case _ => u
+          }
+        }
     }
   }
```
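For context on the `loadFunction`/`bind` handshake in the hunk above: the catalog returns an `UnboundFunction`, and the analyzer binds it to a `StructType` built from the argument expressions (one field `_0`, `_1`, … per argument); an `UnsupportedOperationException` from `bind` is rethrown as an `AnalysisException`. A minimal sketch of the provider side, where the names and validation logic are illustrative and not from this PR:

```java
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Hypothetical UnboundFunction: validates the input schema built by the
// analyzer and returns a bound instance on success.
public class UnboundStrLen implements UnboundFunction {

  @Override
  public String name() {
    return "strlen";
  }

  @Override
  public BoundFunction bind(StructType inputType) {
    if (inputType.fields().length != 1 ||
        !inputType.fields()[0].dataType().equals(DataTypes.StringType)) {
      // Surfaced by the analyzer as an AnalysisException (see the catch above).
      throw new UnsupportedOperationException("strlen expects exactly one STRING argument");
    }
    return new StrLen(); // the hypothetical ScalarFunction sketched earlier
  }

  @Override
  public String description() {
    return "strlen(string) - returns the length of the string in bytes";
  }
}
```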
```diff
+  private def processV2ScalarFunction(
+      scalarFunc: ScalarFunction[_],
+      inputType: StructType,
+      arguments: Seq[Expression],
+      isDistinct: Boolean,
+      filter: Option[Expression],
+      ignoreNulls: Boolean): Expression = {
+    if (isDistinct) {
+      throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+        scalarFunc.name(), "DISTINCT")
+    } else if (filter.isDefined) {
+      throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+        scalarFunc.name(), "FILTER clause")
+    } else if (ignoreNulls) {
+      throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+        scalarFunc.name(), "IGNORE NULLS")
+    } else {
+      // TODO: implement type coercion by looking at the input types from the UDF. We
+      // may also want to check if the parameter types from the magic method
+      // match the input type through `BoundFunction.inputTypes`.
+      val argClasses = inputType.fields.map(_.dataType)
+      findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
+        case Some(_) =>
+          val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
+          Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
+            arguments, returnNullable = scalarFunc.isResultNullable)
+        case _ =>
+          // TODO: handle functions defined in Scala too - in Scala, even if a
+          // subclass does not override the default method in a parent interface
+          // defined in Java, the method can still be found via `getDeclaredMethod`.
+          // Since `inputType` is a `StructType`, it is mapped to an `InternalRow`,
+          // which we can use to look up the `produceResult` method.
+          findMethod(scalarFunc, "produceResult", Seq(inputType)) match {
+            case Some(_) =>
+              ApplyFunctionExpression(scalarFunc, arguments)
+            case None =>
+              failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implements the" +
+                s" magic method nor overrides 'produceResult'")
+          }
+      }
+    }
+  }
```
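Once resolution succeeds, the V2 function is callable like any built-in. A hypothetical end-to-end usage, assuming a `FunctionCatalog` plugin registered under the name `testcat` — both the plugin class and the function namespace below are illustrative:

```java
import org.apache.spark.sql.SparkSession;

public class V2FunctionDemo {
  public static void main(String[] args) {
    // Assumes a FunctionCatalog implementation registered as "testcat" that
    // exposes the hypothetical ns.strlen function shown earlier.
    SparkSession spark = SparkSession.builder()
        .appName("v2-function-demo")
        .master("local[*]")
        // com.example.MyFunctionCatalog is a placeholder plugin class.
        .config("spark.sql.catalog.testcat", "com.example.MyFunctionCatalog")
        .getOrCreate();

    // The analyzer expands and resolves the multi-part function identifier,
    // then binds and evaluates it via the magic method or produceResult.
    spark.sql("SELECT testcat.ns.strlen('spark')").show();
    spark.stop();
  }
}
```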
```diff
+  private def processV2AggregateFunction(
+      aggFunc: V2AggregateFunction[_, _],
+      arguments: Seq[Expression],
+      isDistinct: Boolean,
+      filter: Option[Expression],
+      ignoreNulls: Boolean): Expression = {
+    if (ignoreNulls) {
+      throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+        aggFunc.name(), "IGNORE NULLS")
+    }
+    val aggregator = V2Aggregator(aggFunc, arguments)
+    AggregateExpression(aggregator, Complete, isDistinct, filter)
+  }
```
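`V2Aggregator` wraps the bound `AggregateFunction` into a Catalyst aggregate evaluated in `Complete` mode. For reference, a minimal sketch of the connector-side `AggregateFunction` contract; the function itself is illustrative and not part of this PR:

```java
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

// Hypothetical example: SUM over BIGINT, using java.lang.Long (which is
// Serializable) as the intermediate aggregation state.
public class LongSum implements AggregateFunction<Long, Long> {

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.LongType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.LongType;
  }

  @Override
  public String name() {
    return "long_sum";
  }

  @Override
  public Long newAggregationState() {
    return 0L;
  }

  @Override
  public Long update(Long state, InternalRow input) {
    // Skip null inputs; the analyzer wraps this function in a Complete-mode
    // AggregateExpression (see processV2AggregateFunction above).
    return input.isNullAt(0) ? state : state + input.getLong(0);
  }

  @Override
  public Long merge(Long leftState, Long rightState) {
    return leftState + rightState;
  }

  @Override
  public Long produceResult(Long state) {
    return state;
  }
}
```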
```diff
+  /**
+   * Check if the input `fn` implements the given `methodName` with parameter types specified
+   * via `inputType`.
+   */
+  private def findMethod(
+      fn: BoundFunction,
+      methodName: String,
+      inputType: Seq[DataType]): Option[Method] = {
+    val cls = fn.getClass
+    try {
+      val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass)
+      Some(cls.getDeclaredMethod(methodName, argClasses: _*))
+    } catch {
+      case _: NoSuchMethodException =>
+        None
+    }
+  }
 }
```
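`findMethod` above is a plain `Class.getDeclaredMethod` lookup against the Java classes mapped from the SQL input types (via `ScalaReflection.dataTypeJavaClass`). The same idea in standalone form, reusing the hypothetical `StrLen` class from the earlier sketch:

```java
import java.lang.reflect.Method;
import org.apache.spark.unsafe.types.UTF8String;

public class MagicMethodLookup {
  public static void main(String[] args) {
    // Mirror of the analyzer's lookup: find "invoke" on the UDF class using
    // the Java classes mapped from the input SQL types. StringType maps to
    // UTF8String, so strlen's magic method takes a single UTF8String.
    Class<?>[] argClasses = { UTF8String.class };
    try {
      Method magic = StrLen.class.getDeclaredMethod("invoke", argClasses);
      System.out.println("Found magic method: " + magic);
    } catch (NoSuchMethodException e) {
      System.out.println("No magic method; would fall back to produceResult");
    }
  }
}
```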