Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl, TypedAggUtils}
import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl, TypedAggUtils, UserDefinedFunctionUtils}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
Expand Down Expand Up @@ -1722,34 +1722,36 @@ class SparkConnectPlanner(
}

/**
* Translates a Scala user-defined function from proto to the Catalyst expression.
* Translates a Scala user-defined function or aggregator from proto to the corresponding
* Catalyst expression.
*
* @param fun
* Proto representation of the Scala user-defined function.
* Proto representation of the Scala user-defined function or aggregator.
* @return
* ScalaUDF.
* An expression, either a ScalaUDF or a ScalaAggregator.
*/
private def transformScalaUDF(fun: proto.CommonInlineUserDefinedFunction): Expression = {
val udf = fun.getScalarScalaUdf
val udfPacket = unpackUdf(fun)
if (udf.getAggregate) {
ScalaAggregator(
transformScalaFunction(fun).asInstanceOf[UserDefinedAggregator[Any, Any, Any]],
fun.getArgumentsList.asScala.map(transformExpression).toSeq)
.toAggregateExpression()
} else {
ScalaUDF(
function = udfPacket.function,
dataType = transformDataType(udf.getOutputType),
children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
inputEncoders = udfPacket.inputEncoders.map(e => Try(ExpressionEncoder(e)).toOption),
outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
udfName = Option(fun.getFunctionName),
nullable = udf.getNullable,
udfDeterministic = fun.getDeterministic)
val children = fun.getArgumentsList.asScala.map(transformExpression).toSeq
transformScalaFunction(fun) match {
case udf: SparkUserDefinedFunction =>
UserDefinedFunctionUtils.toScalaUDF(udf, children)
case uda: UserDefinedAggregator[_, _, _] =>
ScalaAggregator(uda, children).toAggregateExpression()
case other =>
throw InvalidPlanInput(
s"Unsupported UserDefinedFunction implementation: ${other.getClass}")
}
}

/**
* Translates a Scala user-defined function or aggregator from proto to a UserDefinedFunction.
*
* @param fun
* Proto representation of the Scala user-defined function or aggregator.
* @return
* A concrete UserDefinedFunction implementation, either a SparkUserDefinedFunction or a
* UserDefinedAggregator.
*/
private def transformScalaFunction(
fun: proto.CommonInlineUserDefinedFunction): UserDefinedFunction = {
val udf = fun.getScalarScalaUdf
Expand Down