From a8eb29ef1a757d81bacfb896843ed023ab75d52c Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Tue, 13 Aug 2024 14:21:12 -0400
Subject: [PATCH 1/5] Clean-up UDF code generation

---
 .../apache/spark/sql/UDFRegistration.scala    | 600 +++---------
 .../sql/expressions/UserDefinedFunction.scala |  47 +-
 .../org/apache/spark/sql/functions.scala      | 130 +---
 .../spark/sql/internal/ToScalaUDF.scala       | 207 ++++++
 4 files changed, 373 insertions(+), 611 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index d0d5beee9945a..e9089ecb8a0bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -20,21 +20,20 @@ package org.apache.spark.sql
 import java.lang.reflect.ParameterizedType
 
 import scala.reflect.runtime.universe.TypeTag
-import scala.util.Try
 
 import org.apache.spark.annotation.Stable
 import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.api.java._
-import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.JavaTypeInference
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.aggregate.ScalaUDAF
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
+import org.apache.spark.sql.internal.ToScalaUDF
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
 
@@ -50,8 +49,6 @@ import org.apache.spark.util.Utils
 @Stable
 class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging {
 
-  import UDFRegistration._
-
   protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
     log.debug(
       s"""
@@ -122,14 +119,53 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
     }
   }
 
+  private def registerScalaUDF(
+      name: String,
+      f: AnyRef,
+      returnTypeTag: TypeTag[_],
+      inputTypeTags: TypeTag[_]*): UserDefinedFunction = {
+    register(name, SparkUserDefinedFunction(f, returnTypeTag, inputTypeTags: _*), "scala_udf")
+  }
+
+  private def registerJavaUDF(
+      name: String,
+      f: AnyRef,
+      returnType: DataType,
+      cardinality: Int): UserDefinedFunction = {
+    val validatedReturnType = CharVarcharUtils.failIfHasCharVarchar(returnType)
+    register(name, SparkUserDefinedFunction(f, validatedReturnType, cardinality), "java_udf")
+  }
+
+  private def register(
+      name: String,
+      udf: SparkUserDefinedFunction,
+      source: String): UserDefinedFunction = {
+    val named = udf.withName(name)
+    val expectedParameterCount = named.inputEncoders.size
+    val builder: Seq[Expression] => Expression = { children =>
+      val actualParameterCount = children.length
+      if (expectedParameterCount == actualParameterCount) {
+        named.createScalaUDF(children)
+      } else {
+        throw QueryCompilationErrors.wrongNumArgsError(
+          name,
+          expectedParameterCount.toString,
+          actualParameterCount)
+      }
+    }
+    functionRegistry.createOrReplaceTempFunction(name, builder, source)
+    named
+  }
+
   // scalastyle:off line.size.limit
 
   /* register 0-22 were generated by this script
 
     (0 to 22).foreach { x =>
-      val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
-      val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
-      val inputEncoders = (1 to x).foldRight("Nil")((i, s) => {s"Try(ExpressionEncoder[A$i]()).toOption :: $s"})
+      val types = (1 to x).foldRight("RT")((i, s) => s"A$i, $s")
+      val typeSeq = "RT" +: (1 to x).map(i => s"A$i")
+      val typeTags = typeSeq.map(t => s"$t: TypeTag").mkString(", ")
+      val implicitTypeTags = typeSeq.map(t => s"implicitly[TypeTag[$t]]").mkString(", ")
       println(s"""
         |/**
         | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF).
@@ -137,42 +173,20 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
         | * @since 1.3.0
         | */
         |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
-        |  val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-        |  val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
-        |  val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = $inputEncoders
-        |  val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
-        |  val finalUdf = if (nullable) udf else udf.asNonNullable()
-        |  def builder(e: Seq[Expression]) = if (e.length == $x) {
-        |    finalUdf.createScalaUDF(e)
-        |  } else {
-        |    throw QueryCompilationErrors.wrongNumArgsError(name, "$x", e.length)
-        |  }
-        |  functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf")
-        |  finalUdf
+        |  registerScalaUDF(name, func, $implicitTypeTags)
        |}""".stripMargin)
     }
 
     (0 to 22).foreach { i =>
       val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
-      val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
-      val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
-      val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
       val version = if (i == 0) "2.3.0" else "1.3.0"
-      val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
       println(s"""
         |/**
         | * Register a deterministic Java UDF$i instance as user-defined function (UDF).
| * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { - | val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - | val func = $funcCall - | def builder(e: Seq[Expression]) = if (e.length == $i) { - | ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - | } else { - | throw QueryCompilationErrors.wrongNumArgsError(name, "$i", e.length) - | } - | functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + | registerJavaUDF(name, ToScalaUDF(f), returnType, $i) |}""".stripMargin) } */ @@ -183,18 +197,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 0) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "0", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]]) } /** @@ -203,18 +206,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 1) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "1", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]]) } /** @@ -223,18 +215,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 2) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "2", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], 
implicitly[TypeTag[A1]], implicitly[TypeTag[A2]]) } /** @@ -243,18 +224,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 3) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "3", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]]) } /** @@ -263,18 +233,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 4) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "4", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]]) } /** @@ -283,18 +242,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 5) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "5", 
e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]]) } /** @@ -303,18 +251,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 6) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "6", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]]) } /** @@ -323,18 +260,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 7) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "7", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]]) } /** @@ -343,18 +269,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): 
UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 8) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "8", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]]) } /** @@ -363,18 +278,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 9) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "9", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]]) } /** @@ -383,18 +287,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: 
Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 10) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "10", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]]) } /** @@ -403,18 +296,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 11) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "11", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]]) } /** @@ -423,18 +305,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, 
nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 12) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "12", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]]) } /** @@ -443,18 +314,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 13) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "13", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]]) } /** @@ -463,18 +323,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * 
@since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 14) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "14", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]]) } /** @@ -483,18 +332,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length 
== 15) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "15", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]]) } /** @@ -503,18 +341,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 16) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "16", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]]) } /** @@ -523,18 +350,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, 
nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 17) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "17", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]]) } /** @@ -543,18 +359,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 18) { - finalUdf.createScalaUDF(e) - } else { - throw 
QueryCompilationErrors.wrongNumArgsError(name, "18", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]]) } /** @@ -563,18 +368,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 19) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "19", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]]) } /** @@ -583,18 +377,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, 
A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 20) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "20", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]]) } /** @@ -603,18 +386,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: 
Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 21) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "21", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]]) } /** @@ -623,18 +395,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 22) { - finalUdf.createScalaUDF(e) - } else { - throw 
QueryCompilationErrors.wrongNumArgsError(name, "22", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]], implicitly[TypeTag[A22]]) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -733,14 +494,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = () => f.asInstanceOf[UDF0[Any]].call() - def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "0", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 0) } /** @@ -748,14 +502,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "1", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 1) } /** @@ -763,14 +510,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "2", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 2) } /** @@ -778,14 +518,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "3", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 3) } /** @@ -793,14 +526,7 @@ class UDFRegistration 
private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "4", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 4) } /** @@ -808,14 +534,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "5", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 5) } /** @@ -823,14 +542,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "6", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 6) } /** @@ -838,14 +550,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "7", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 7) } /** @@ -853,14 +558,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "8", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + 
registerJavaUDF(name, ToScalaUDF(f), returnType, 8) } /** @@ -868,14 +566,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "9", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 9) } /** @@ -883,14 +574,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "10", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 10) } /** @@ -898,14 +582,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "11", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 11) } /** @@ -913,14 +590,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "12", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 12) } /** @@ -928,14 +598,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = 
CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "13", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 13) } /** @@ -943,14 +606,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "14", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 14) } /** @@ -958,14 +614,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "15", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 15) } /** @@ -973,14 +622,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "16", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 16) } /** @@ -988,14 +630,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val 
func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "17", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 17) } /** @@ -1003,14 +638,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "18", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 18) } /** @@ -1018,14 +646,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "19", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 19) } /** @@ -1033,14 +654,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "20", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 20) } /** @@ -1048,14 +662,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def 
register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "21", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 21) } /** @@ -1063,30 +670,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "22", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + registerJavaUDF(name, ToScalaUDF(f), returnType, 22) } // scalastyle:on line.size.limit } - -private[sql] object UDFRegistration { - /** - * Obtaining the schema of output encoder for `ScalaUDF`. - * - * As the serialization in `ScalaUDF` is for individual column, not the whole row, - * we just take the data type of vanilla object serializer, not `serializer` which - * is transformed somehow for top-level row. 
- */ - def outputSchema(outputEncoder: ExpressionEncoder[_]): ScalaReflection.Schema = { - ScalaReflection.Schema(outputEncoder.objSerializer.dataType, - outputEncoder.objSerializer.nullable) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index a75384fb0f4e0..39ff44126b3be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.expressions +import scala.reflect.runtime.universe.TypeTag +import scala.util.Try + import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Encoder} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaAggregator import org.apache.spark.sql.types.DataType @@ -89,8 +93,8 @@ sealed abstract class UserDefinedFunction { private[spark] case class SparkUserDefinedFunction( f: AnyRef, dataType: DataType, - inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil, - outputEncoder: Option[ExpressionEncoder[_]] = None, + inputEncoders: Seq[Option[Encoder[_]]] = Nil, + outputEncoder: Option[Encoder[_]] = None, name: Option[String] = None, nullable: Boolean = true, deterministic: Boolean = true) extends UserDefinedFunction { @@ -105,8 +109,8 @@ private[spark] case class SparkUserDefinedFunction( f, dataType, exprs, - inputEncoders, - outputEncoder, + inputEncoders.map(_.map(e => encoderFor(e))), + outputEncoder.map(e => encoderFor(e)), udfName = name, nullable = nullable, udfDeterministic = deterministic) @@ -133,6 +137,35 @@ private[spark] case class SparkUserDefinedFunction( } } +object SparkUserDefinedFunction { + private[sql] def apply( + function: AnyRef, + returnTypeTag: TypeTag[_], + inputTypeTags: TypeTag[_]*): SparkUserDefinedFunction = { + val outputEncoder = ScalaReflection.encoderFor(returnTypeTag) + val inputEncoders = inputTypeTags.map { tag => + Try(ScalaReflection.encoderFor(tag)).toOption + } + SparkUserDefinedFunction( + f = function, + inputEncoders = inputEncoders, + dataType = outputEncoder.dataType, + outputEncoder = Option(outputEncoder), + nullable = outputEncoder.nullable) + } + + private[sql] def apply( + function: AnyRef, + returnType: DataType, + cardinality: Int): SparkUserDefinedFunction = { + SparkUserDefinedFunction( + function, + returnType, + inputEncoders = Seq.fill(cardinality)(None), + None) + } +} + private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( aggregator: Aggregator[IN, BUF, OUT], inputEncoder: Encoder[IN], @@ -147,8 +180,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( // This is also used by udf.register(...) 
when it detects a UserDefinedAggregator def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = { - val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]] - val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]] + val iEncoder = encoderFor(inputEncoder) + val bEncoder = encoderFor(aggregator.bufferEncoder) ScalaAggregator( exprs, aggregator, iEncoder, bEncoder, nullable, deterministic, aggregatorName = name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f0667ba94a4ec..fc9e5f3e8f72a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -21,19 +21,18 @@ import java.util.Collections import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag -import scala.util.Try import org.apache.spark.annotation.Stable import org.apache.spark.sql.api.java._ -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.ScalaReflection.encoderFor import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SQLConf, ToScalaUDF} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -432,7 +431,7 @@ object functions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(ExpressionEncoder[Long]()) + count(Column(columnName)).as(AgnosticEncoders.PrimitiveLongEncoder) /** * Aggregate function: returns the number of distinct items in a group. @@ -7899,9 +7898,10 @@ object functions { /* Use the following code to generate: (0 to 10).foreach { x => - val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputEncoders = (1 to x).foldRight("Nil")((i, s) => {s"Try(ExpressionEncoder[A$i]()).toOption :: $s"}) + val types = (1 to x).foldRight("RT")((i, s) => s"A$i, $s") + val typeSeq = "RT" +: (1 to x).map(i => s"A$i") + val typeTags = typeSeq.map(t => s"$t: TypeTag").mkString(", ") + val implicitTypeTags = typeSeq.map(t => s"implicitly[TypeTag[$t]]").mkString(", ") println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF). 
@@ -7913,20 +7913,12 @@ object functions { | * @since 1.3.0 | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - | val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - | val inputEncoders = $inputEncoders - | val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - | if (nullable) udf else udf.asNonNullable() + | SparkUserDefinedFunction(f, $implicitTypeTags) |}""".stripMargin) } (0 to 10).foreach { i => val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") - val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") - val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" - val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") - val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)" println(s""" |/** | * Defines a Java UDF$i instance as user-defined function (UDF). @@ -7938,8 +7930,7 @@ object functions { | * @since 2.3.0 | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { - | val func = $funcCall - | SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill($i)(None)) + | SparkUserDefinedFunction(ToScalaUDF(f), returnType, $i) |}""".stripMargin) } @@ -7975,7 +7966,7 @@ object functions { * @note The input encoder is inferred from the input type IN. */ def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction = { - udaf(agg, ExpressionEncoder[IN]()) + udaf(agg, encoderFor[IN]) } /** @@ -8022,11 +8013,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]]) } /** @@ -8039,11 +8026,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]]) } /** @@ -8056,11 +8039,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]]) } /** @@ -8073,11 +8052,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, 
A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]]) } /** @@ -8090,11 +8065,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]]) } /** @@ -8107,11 +8078,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]]) } /** @@ -8124,11 +8091,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]]) } 
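// A minimal usage sketch of the TypeTag-based factory above (illustrative only; it assumes
// an active SparkSession named `spark`, and `col` from org.apache.spark.sql.functions):
//
//   val plusOne = udf((i: Int) => i + 1)    // RT and A1 TypeTags are inferred implicitly
//   spark.range(3).select(plusOne(col("id").cast("int")))
//
// Because the inferred encoder for Int is primitive (non-nullable), the resulting UDF is
// marked non-nullable, preserving the old `if (nullable) udf else udf.asNonNullable()` logic.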
/** @@ -8141,11 +8104,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]]) } /** @@ -8158,11 +8117,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]]) } /** @@ -8175,11 +8130,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], 
implicitly[TypeTag[A8]], implicitly[TypeTag[A9]]) } /** @@ -8192,11 +8143,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]]) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -8213,8 +8160,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { - val func = () => f.asInstanceOf[UDF0[Any]].call() - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(0)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 0) } /** @@ -8227,8 +8173,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(1)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 1) } /** @@ -8241,8 +8186,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(2)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 2) } /** @@ -8255,8 +8199,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(3)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 3) } /** @@ -8269,8 +8212,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(4)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 4) } /** @@ -8283,8 +8225,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, 
inputEncoders = Seq.fill(5)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 5) } /** @@ -8297,8 +8238,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(6)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 6) } /** @@ -8311,8 +8251,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(7)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 7) } /** @@ -8325,8 +8264,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(8)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 8) } /** @@ -8339,8 +8277,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(9)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 9) } /** @@ -8353,8 +8290,7 @@ object functions { * @since 2.3.0 */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(10)(None)) + SparkUserDefinedFunction(ToScalaUDF(f), returnType, 10) } // scalastyle:on parameter.number diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala new file mode 100644 index 0000000000000..6f5df1397aa62 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.sql.internal
+
+import org.apache.spark.sql.api.java._
+
+/**
+ * Helper class that provided conversions from org.apache.spark.sql.api.java.Function* to
+ * scala.Function*.
+ */
+private[sql] object ToScalaUDF {
+  // scalastyle:off line.size.limit
+
+  /* apply 0-22 were generated by this script
+
+  (0 to 22).foreach { i =>
+    val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
+    val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
+    val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
+    val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
+    val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
+    println(s"""
+      |/**
+      | * Create a scala.Function$i wrapper for a org.apache.spark.sql.api.java.UDF$i instance.
+      | */
+      |def apply(f: UDF$i[$extTypeArgs]): AnyRef = {
+      |  $funcCall
+      |}""".stripMargin)
+  }
+  */
+
+  /**
+   * Create a scala.Function0 wrapper for a org.apache.spark.sql.api.java.UDF0 instance.
+   */
+  def apply(f: UDF0[_]): AnyRef = {
+    () => f.asInstanceOf[UDF0[Any]].call()
+  }
+
+  /**
+   * Create a scala.Function1 wrapper for a org.apache.spark.sql.api.java.UDF1 instance.
+   */
+  def apply(f: UDF1[_, _]): AnyRef = {
+    f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
+  }
+
+  /**
+   * Create a scala.Function2 wrapper for a org.apache.spark.sql.api.java.UDF2 instance.
+   */
+  def apply(f: UDF2[_, _, _]): AnyRef = {
+    f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
+  }
+
+  /**
+   * Create a scala.Function3 wrapper for a org.apache.spark.sql.api.java.UDF3 instance.
+   */
+  def apply(f: UDF3[_, _, _, _]): AnyRef = {
+    f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
+  }
+
+  /**
+   * Create a scala.Function4 wrapper for a org.apache.spark.sql.api.java.UDF4 instance.
+   */
+  def apply(f: UDF4[_, _, _, _, _]): AnyRef = {
+    f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
+  }
+
+  /**
+   * Create a scala.Function5 wrapper for a org.apache.spark.sql.api.java.UDF5 instance.
+   */
+  def apply(f: UDF5[_, _, _, _, _, _]): AnyRef = {
+    f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
+  }
+
+  /**
+   * Create a scala.Function6 wrapper for a org.apache.spark.sql.api.java.UDF6 instance.
+   */
+  def apply(f: UDF6[_, _, _, _, _, _, _]): AnyRef = {
+    f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+  }
+
+  /**
+   * Create a scala.Function7 wrapper for a org.apache.spark.sql.api.java.UDF7 instance.
+   */
+  def apply(f: UDF7[_, _, _, _, _, _, _, _]): AnyRef = {
+    f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+  }
+
+  /**
+   * Create a scala.Function8 wrapper for a org.apache.spark.sql.api.java.UDF8 instance.
+   */
+  def apply(f: UDF8[_, _, _, _, _, _, _, _, _]): AnyRef = {
+    f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+  }
+
+  /**
+   * Create a scala.Function9 wrapper for a org.apache.spark.sql.api.java.UDF9 instance.
+   */
+  def apply(f: UDF9[_, _, _, _, _, _, _, _, _, _]): AnyRef = {
+    f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+  }
+
+  /**
+   * Create a scala.Function10 wrapper for a org.apache.spark.sql.api.java.UDF10 instance.
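+   * (For illustration: the instance is cast to `UDF10[Any, ..., Any]`, which is safe under
+   * type erasure, and partially applied into a `Function10` that forwards its ten arguments
+   * to `f.call`.)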
+ */ + def apply(f: UDF10[_, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function11 wrapper for a org.apache.spark.sql.api.java.UDF11 instance. + */ + def apply(f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function12 wrapper for a org.apache.spark.sql.api.java.UDF12 instance. + */ + def apply(f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function13 wrapper for a org.apache.spark.sql.api.java.UDF13 instance. + */ + def apply(f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function14 wrapper for a org.apache.spark.sql.api.java.UDF14 instance. + */ + def apply(f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function15 wrapper for a org.apache.spark.sql.api.java.UDF15 instance. + */ + def apply(f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function16 wrapper for a org.apache.spark.sql.api.java.UDF16 instance. + */ + def apply(f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function17 wrapper for a org.apache.spark.sql.api.java.UDF17 instance. + */ + def apply(f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function18 wrapper for a org.apache.spark.sql.api.java.UDF18 instance. 
+ */ + def apply(f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function19 wrapper for a org.apache.spark.sql.api.java.UDF19 instance. + */ + def apply(f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function20 wrapper for a org.apache.spark.sql.api.java.UDF20 instance. + */ + def apply(f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function21 wrapper for a org.apache.spark.sql.api.java.UDF21 instance. + */ + def apply(f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + + /** + * Create a scala.Function22 wrapper for a org.apache.spark.sql.api.java.UDF22 instance. + */ + def apply(f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { + f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + } + // scalastyle:on line.size.limit +} From d9d4f74c5e0433659f3642344a66024c4c238590 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 13 Aug 2024 14:28:39 -0400 Subject: [PATCH 2/5] Typo --- .../main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala index 6f5df1397aa62..bf82e9def5ef2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.sql.api.java._ /** - * Helper class that provided conversions from org.apache.spark.sql.api.java.Function* to + * Helper class that provides conversions from org.apache.spark.sql.api.java.Function* to * scala.Function*. 
*/ private[sql] object ToScalaUDF { From b2d153c2808032808f013dcdb4d63d064ccb1fa7 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 13 Aug 2024 15:45:19 -0400 Subject: [PATCH 3/5] Let's not forget connect... --- .../sql/connect/planner/SparkConnectPlanner.scala | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 43b300a11a49d..25564ec8d5c51 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1899,15 +1899,7 @@ class SparkConnectPlanner( fun: org.apache.spark.sql.expressions.UserDefinedFunction, exprs: Seq[Expression]): ScalaUDF = { val f = fun.asInstanceOf[org.apache.spark.sql.expressions.SparkUserDefinedFunction] - ScalaUDF( - function = f.f, - dataType = f.dataType, - children = exprs, - inputEncoders = f.inputEncoders, - outputEncoder = f.outputEncoder, - udfName = f.name, - nullable = f.nullable, - udfDeterministic = f.deterministic) + f.createScalaUDF(exprs) } private def extractProtobufArgs(children: Seq[Expression]) = { From a56d2595a3725d1ae30abda3535cad3d578557f2 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 13 Aug 2024 20:23:42 -0400 Subject: [PATCH 4/5] Bugs --- .../connect/client/CheckConnectJvmClientCompatibility.scala | 5 ++++- .../main/scala/org/apache/spark/sql/UDFRegistration.scala | 2 +- .../apache/spark/sql/expressions/UserDefinedFunction.scala | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 948211e9d1f7a..59b399da9a5c6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -196,7 +196,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.ExtendedExplainGenerator"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDTFRegistration"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataSourceRegistration"), // DataFrame Reader & Writer @@ -297,6 +296,10 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.artifact.util.ArtifactUtils$"), + // UDFRegistration + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.register"), + // Datasource V2 partition transforms ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index e9089ecb8a0bf..e5999355133e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -145,7 +145,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val builder: Seq[Expression] => Expression = { children => val actualParameterCount = children.length if (expectedParameterCount == actualParameterCount) { - udf.createScalaUDF(children) + named.createScalaUDF(children) } else { throw QueryCompilationErrors.wrongNumArgsError( name, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 39ff44126b3be..d18deff70a646 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -23,6 +23,7 @@ import scala.util.Try import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Encoder} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaAggregator @@ -109,7 +110,7 @@ private[spark] case class SparkUserDefinedFunction( f, dataType, exprs, - inputEncoders.map(_.map(e => encoderFor(e))), + inputEncoders.map(_.filter(_ != UnboundRowEncoder).map(e => encoderFor(e))), outputEncoder.map(e => encoderFor(e)), udfName = name, nullable = nullable, From 5232008568267e20fe39d0b7fb17cc7621513d83 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 13 Aug 2024 20:38:23 -0400 Subject: [PATCH 5/5] Fix after merge --- .../spark/sql/internal/columnNodeSupport.scala | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 0fa885b34f9fb..4d4960d24d010 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -20,7 +20,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.{analysis, expressions, CatalystTypeConverters} import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ScalaUDF} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} @@ -176,15 +176,7 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres ScalaUDAF(udaf = a, children = arguments.map(apply)).toAggregateExpression(isDistinct) case InvokeInlineUserDefinedFunction(udf: SparkUserDefinedFunction, arguments, _, _) => - ScalaUDF( - function = udf.f, - dataType = udf.dataType, - children = arguments.map(apply), - inputEncoders = udf.inputEncoders, - outputEncoder = udf.outputEncoder, - udfName = udf.givenName, - nullable = udf.deterministic, - udfDeterministic = udf.deterministic) + udf.createScalaUDF(arguments.map(apply)) case Wrapper(expression, _) => expression
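
For reference, here is a minimal end-to-end sketch of the consolidated Java UDF path after this series. It is illustrative only: the session setup and the `to_upper` name are assumptions, not part of the patches. `register(name, f, returnType)` now wraps the Java instance with `ToScalaUDF` and delegates to the shared `registerJavaUDF` / `SparkUserDefinedFunction(f, returnType, cardinality)` helpers instead of a hand-rolled builder per arity.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.StringType

object JavaUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("udf-cleanup-sketch").getOrCreate()

    // A Java-style UDF1; register(...) converts it via ToScalaUDF and builds a
    // SparkUserDefinedFunction with returnType = StringType and cardinality = 1.
    val toUpper = new UDF1[String, String] {
      override def call(s: String): String = if (s == null) null else s.toUpperCase
    }
    spark.udf.register("to_upper", toUpper, StringType)

    // The shared builder checks arity: calling to_upper with two arguments would now
    // fail with the same wrongNumArgsError as every other register(...) overload.
    spark.sql("SELECT to_upper('hello')").show()

    spark.stop()
  }
}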