From d2ea80dd80939eef85e3eb966ca648b194a49bad Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 23 Jul 2024 21:00:02 -0400 Subject: [PATCH 01/18] Integrate ColumnNode AST into Column.scala --- .../org/apache/spark/sql/avro/functions.scala | 8 +- .../apache/spark/sql/protobuf/functions.scala | 20 +- .../spark/ml/feature/StringIndexer.scala | 3 +- .../org/apache/spark/ml/stat/Summarizer.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 184 +++++++++--------- .../spark/sql/DataFrameNaFunctions.scala | 2 +- .../sql/execution/stat/StatFunctions.scala | 6 +- .../spark/sql/expressions/WindowSpec.scala | 2 +- .../org/apache/spark/sql/functions.scala | 92 ++++----- .../spark/sql/DataFrameComplexTypeSuite.scala | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 6 +- .../sql/streaming/StreamingQuerySuite.scala | 9 +- 12 files changed, 152 insertions(+), 186 deletions(-) diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala index 5830b2ec42383..1af7558200de3 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala @@ -41,7 +41,7 @@ object functions { def from_avro( data: Column, jsonFormatSchema: String): Column = { - new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty)) + Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty)) } /** @@ -62,7 +62,7 @@ object functions { data: Column, jsonFormatSchema: String, options: java.util.Map[String, String]): Column = { - new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap)) + Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap)) } /** @@ -74,7 +74,7 @@ object functions { */ @Experimental def to_avro(data: Column): Column = { - new Column(CatalystDataToAvro(data.expr, None)) + Column(CatalystDataToAvro(data.expr, None)) } /** @@ -87,6 +87,6 @@ object functions { */ @Experimental def to_avro(data: Column, jsonFormatSchema: String): Column = { - new Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema))) + Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema))) } } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 91e87dee50482..2700764399606 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -70,7 +70,7 @@ object functions { messageName: String, binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String]): Column = { - new Column( + Column( ProtobufDataToCatalyst( data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap ) @@ -93,7 +93,7 @@ object functions { @Experimental def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent))) + Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent))) } /** @@ -112,7 +112,7 @@ object functions { @Experimental def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet))) + 
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet))) } /** @@ -132,7 +132,7 @@ object functions { */ @Experimental def from_protobuf(data: Column, messageClassName: String): Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageClassName)) + Column(ProtobufDataToCatalyst(data.expr, messageClassName)) } /** @@ -156,7 +156,7 @@ object functions { data: Column, messageClassName: String, options: java.util.Map[String, String]): Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap)) + Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap)) } /** @@ -194,7 +194,7 @@ object functions { @Experimental def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - new Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet))) + Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet))) } /** * Converts a column into binary of protobuf format. The Protobuf definition is provided @@ -216,7 +216,7 @@ object functions { descFilePath: String, options: java.util.Map[String, String]): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - new Column( + Column( CatalystDataToProtobuf(data.expr, messageName, Some(fileContent), options.asScala.toMap) ) } @@ -242,7 +242,7 @@ object functions { binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String] ): Column = { - new Column( + Column( CatalystDataToProtobuf( data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap ) @@ -266,7 +266,7 @@ object functions { */ @Experimental def to_protobuf(data: Column, messageClassName: String): Column = { - new Column(CatalystDataToProtobuf(data.expr, messageClassName)) + Column(CatalystDataToProtobuf(data.expr, messageClassName)) } /** @@ -288,6 +288,6 @@ object functions { @Experimental def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) : Column = { - new Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap)) + Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 94d4fa6fe6f20..cfc6b8e0e5ac6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -30,7 +30,6 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -199,7 +198,7 @@ class StringIndexer @Since("1.4.0") ( } else { // We don't count for NaN values. Because `StringIndexerAggregator` only processes strings, // we replace NaNs with null in advance. 
- new Column(If(col.isNaN.expr, Literal(null), col.expr)).cast(StringType) + nanvl(col, lit(null)).cast(StringType) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 7a27b32aa24c5..9388205a751ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -257,7 +257,7 @@ private[ml] class SummaryBuilderImpl( mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - new Column(agg.toAggregateExpression()) + Column(agg.toAggregateExpression()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index dc804a72ad93e..d5bbb57786486 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -27,22 +27,26 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.TypedAggUtils +import org.apache.spark.sql.internal.{ColumnNode, Extension, TypedAggUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ -private[sql] object Column { +private[spark] object Column { def apply(colName: String): Column = new Column(colName) - def apply(expr: Expression): Column = new Column(expr) + // TODO move this to a separate class! + // Move as much as we can to the new API + // Create internal util for create expression(nodes) + def apply(expr: Expression): Column = Column(Extension(expr)) - def unapply(col: Column): Option[Expression] = Some(col.expr) + def apply(node: => ColumnNode): Column = withOrigin(new Column(node)) + // TODO move this else where... private[sql] def generateAlias(e: Expression): String = { e match { case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => @@ -51,6 +55,7 @@ private[sql] object Column { } } + // TODO move this else where... private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = { val metadataWithoutId = new MetadataBuilder() .withMetadata(a.metadata) @@ -77,9 +82,9 @@ private[sql] object Column { isDistinct: Boolean, isInternal: Boolean, inputs: Seq[Column]): Column = withOrigin { - Column(UnresolvedFunction( - name :: Nil, - inputs.map(_.expr), + Column(internal.UnresolvedFunction( + name, + inputs.map(_.node), isDistinct = isDistinct, isInternal = isInternal)) } @@ -97,9 +102,13 @@ private[sql] object Column { */ @Stable class TypedColumn[-T, U]( - expr: Expression, + node: ColumnNode, private[sql] val encoder: ExpressionEncoder[U]) - extends Column(expr) { + extends Column(node) { + + // TODO get rid of this. + // This requires one or two more ColumnNodes... 
+ def this(expr: Expression, encoder: ExpressionEncoder[U]) = this(Extension(expr), encoder) /** * Inserts the specific input type and schema into any expressions that are expected to operate @@ -121,7 +130,7 @@ class TypedColumn[-T, U]( * @since 2.0.0 */ override def name(alias: String): TypedColumn[T, U] = - new TypedColumn[T, U](super.name(alias).expr, encoder) + new TypedColumn[T, U](super.name(alias).node, encoder) } @@ -145,9 +154,6 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * @note The internal Catalyst expression can be accessed via [[expr]], but this method is for - * debugging purposes only and can change in any future Spark releases. - * * @groupname java_expr_ops Java-specific expression operators * @groupname expr_ops Expression operators * @groupname df_ops DataFrame functions @@ -156,15 +162,16 @@ class TypedColumn[-T, U]( * @since 1.3.0 */ @Stable -class Column(val expr: Expression) extends Logging { +class Column(val node: ColumnNode) extends Logging { + // TODO this will be moved to the calling classes. + // We must, must, must move all user facing use cases. + lazy val expr: Expression = internal.ColumnNodeToExpressionConverter(node) def this(name: String) = this(withOrigin { name match { - case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => - val parts = UnresolvedAttribute.parseAttributeName(name.dropRight(2)) - UnresolvedStar(Some(parts)) - case _ => UnresolvedAttribute.quotedString(name) + case "*" => internal.UnresolvedStar(None) + case _ if name.endsWith(".*") => internal.UnresolvedStar(Option(name.dropRight(2))) + case _ => internal.UnresolvedAttribute(name) } }) @@ -181,24 +188,16 @@ class Column(val expr: Expression) extends Logging { override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { - case that: Column => that.normalizedExpr() == this.normalizedExpr() + case that: Column => that.node == this.node case _ => false } - override def hashCode: Int = this.normalizedExpr().hashCode() - - private def normalizedExpr(): Expression = expr transform { - case a: AttributeReference => Column.stripColumnReferenceMetadata(a) - } - - /** Creates a column based on the given expression. */ - private def withExpr(newExpr: => Expression): Column = withOrigin { - new Column(newExpr) - } + override def hashCode: Int = this.node.hashCode() /** * Returns the expression for this column either with an existing or auto assigned name. */ + // TODO move this elsewhere. private[sql] def named: NamedExpression = expr match { case expr: NamedExpression => expr @@ -225,7 +224,7 @@ class Column(val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](node, encoderFor[U]) /** * Extracts a value or values from a complex type. 
@@ -240,8 +239,8 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def apply(extraction: Any): Column = withExpr { - UnresolvedExtractValue(expr, lit(extraction).expr) + def apply(extraction: Any): Column = Column { + internal.UnresolvedExtractValue(node, lit(extraction).node) } /** @@ -291,14 +290,18 @@ class Column(val expr: Expression) extends Logging { * @since 1.3.0 */ def ===(other: Any): Column = { - val right = lit(other).expr - if (this.expr == right) { + val right = lit(other) + checkTrivialPredicate(right) + fn("=", other) + } + + private def checkTrivialPredicate(right: Column): Unit = { + if (this == right) { logWarning( log"Constructing trivially true equals predicate, " + - log"'${MDC(LEFT_EXPR, this.expr)} = ${MDC(RIGHT_EXPR, right)}'. " + + log"'${MDC(LEFT_EXPR, this)} <=> ${MDC(RIGHT_EXPR, right)}'. " + log"Perhaps you need to use aliases.") } - fn("=", other) } /** @@ -498,14 +501,9 @@ class Column(val expr: Expression) extends Logging { * @since 1.3.0 */ def <=>(other: Any): Column = { - val right = lit(other).expr - if (this.expr == right) { - logWarning( - log"Constructing trivially true equals predicate, " + - log"'${MDC(LEFT_EXPR, this.expr)} <=> ${MDC(RIGHT_EXPR, right)}'. " + - log"Perhaps you need to use aliases.") - } - fn("<=>", other) + val right = lit(other) + checkTrivialPredicate(right) + fn("<=>", right) } /** @@ -537,11 +535,11 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = withExpr { - this.expr match { - case CaseWhen(branches, None) => - CaseWhen(branches :+ ((condition.expr, lit(value).expr))) - case CaseWhen(_, Some(_)) => + def when(condition: Column, value: Any): Column = Column { + node match { + case internal.CaseWhenOtherwise(branches, None, _) => + internal.CaseWhenOtherwise(branches :+ ((condition.node, lit(value).node)), None) + case internal.CaseWhenOtherwise(_, Some(_), _) => throw new IllegalArgumentException( "when() cannot be applied once otherwise() is applied") case _ => @@ -571,11 +569,11 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def otherwise(value: Any): Column = withExpr { - this.expr match { - case CaseWhen(branches, None) => - CaseWhen(branches, Option(lit(value).expr)) - case CaseWhen(_, Some(_)) => + def otherwise(value: Any): Column = Column { + node match { + case internal.CaseWhenOtherwise(branches, None, _) => + internal.CaseWhenOtherwise(branches, Option(lit(value).node)) + case internal.CaseWhenOtherwise(_, Some(_), _) => throw new IllegalArgumentException( "otherwise() can only be applied once on a Column previously generated by when()") case _ => @@ -951,10 +949,10 @@ class Column(val expr: Expression) extends Logging { * @since 3.1.0 */ // scalastyle:on line.size.limit - def withField(fieldName: String, col: Column): Column = withExpr { + def withField(fieldName: String, col: Column): Column = { require(fieldName != null, "fieldName cannot be null") require(col != null, "col cannot be null") - UpdateFields(expr, fieldName, col.expr) + Column(internal.UpdateFields(node, fieldName, Option(col.node))) } // scalastyle:off line.size.limit @@ -1017,9 +1015,9 @@ class Column(val expr: Expression) extends Logging { * @since 3.1.0 */ // scalastyle:on line.size.limit - def dropFields(fieldNames: String*): Column = withExpr { - fieldNames.tail.foldLeft(UpdateFields(expr, fieldNames.head)) { - (resExpr, fieldName) => UpdateFields(resExpr, 
fieldName) + def dropFields(fieldNames: String*): Column = Column { + fieldNames.tail.foldLeft(internal.UpdateFields(node, fieldNames.head)) { + (resExpr, fieldName) => internal.UpdateFields(resExpr, fieldName) } } @@ -1129,7 +1127,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Seq[String]): Column = withExpr { MultiAlias(expr, aliases) } + def as(aliases: Seq[String]): Column = Column(internal.Alias(node, aliases)) /** * Assigns the given aliases to the results of a table generating function. @@ -1141,9 +1139,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Array[String]): Column = withExpr { - MultiAlias(expr, aliases.toImmutableArraySeq) - } + def as(aliases: Array[String]): Column = as(aliases.toImmutableArraySeq) /** * Gives the column an alias. @@ -1171,9 +1167,8 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String, metadata: Metadata): Column = withExpr { - Alias(expr, alias)(explicitMetadata = Some(metadata)) - } + def as(alias: String, metadata: Metadata): Column = + Column(internal.Alias(node, alias :: Nil, metadata = Option(metadata))) /** * Gives the column a name (alias). @@ -1189,13 +1184,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.0.0 */ - def name(alias: String): Column = withExpr { - // SPARK-33536: an alias is no longer a column reference. Therefore, - // we should not inherit the column reference related metadata in an alias - // so that it is not caught as a column reference in DetectAmbiguousSelfJoin. - Alias(expr, alias)( - nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) - } + def name(alias: String): Column = Column(internal.Alias(node, alias :: Nil)) /** * Casts the column to a different data type. @@ -1211,11 +1200,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = withExpr { - val cast = Cast(expr, CharVarcharUtils.replaceCharVarcharWithStringForCast(to)) - cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) - cast - } + def cast(to: DataType): Column = Column(internal.Cast(node, to)) /** * Casts the column to a different data type, using the canonical string representation @@ -1245,13 +1230,8 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 4.0.0 */ - def try_cast(to: DataType): Column = withExpr { - val cast = Cast( - child = expr, - dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(to), - evalMode = EvalMode.TRY) - cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) - cast + def try_cast(to: DataType): Column = { + Column(internal.Cast(node, to, Option(internal.Cast.EvalMode.Try))) } /** @@ -1268,6 +1248,12 @@ class Column(val expr: Expression) extends Logging { try_cast(CatalystSqlParser.parseDataType(to)) } + private def sortOrder( + sortDirection: internal.SortOrder.SortDirection.Value, + nullOrdering: internal.SortOrder.NullOrdering.Value): Column = { + Column(internal.SortOrder(node, sortDirection, nullOrdering)) + } + /** * Returns a sort expression based on the descending order of the column. 
* {{{ @@ -1281,7 +1267,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def desc: Column = withExpr { SortOrder(expr, Descending) } + def desc: Column = desc_nulls_last /** * Returns a sort expression based on the descending order of the column, @@ -1297,7 +1283,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Seq.empty) } + def desc_nulls_first: Column = sortOrder( + internal.SortOrder.SortDirection.Descending, + internal.SortOrder.NullOrdering.NullsFirst) /** * Returns a sort expression based on the descending order of the column, @@ -1313,7 +1301,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Seq.empty) } + def desc_nulls_last: Column = sortOrder( + internal.SortOrder.SortDirection.Descending, + internal.SortOrder.NullOrdering.NullsLast) /** * Returns a sort expression based on ascending order of the column. @@ -1328,7 +1318,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def asc: Column = withExpr { SortOrder(expr, Ascending) } + def asc: Column = asc_nulls_first /** * Returns a sort expression based on ascending order of the column, @@ -1344,7 +1334,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Seq.empty) } + def asc_nulls_first: Column = sortOrder( + internal.SortOrder.SortDirection.Ascending, + internal.SortOrder.NullOrdering.NullsFirst) /** * Returns a sort expression based on ascending order of the column, @@ -1360,7 +1352,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Seq.empty) } + def asc_nulls_last: Column = sortOrder( + internal.SortOrder.SortDirection.Ascending, + internal.SortOrder.NullOrdering.NullsLast) /** * Prints the expression to the console for debugging purposes. @@ -1371,9 +1365,9 @@ class Column(val expr: Expression) extends Logging { def explain(extended: Boolean): Unit = { // scalastyle:off println if (extended) { - println(expr) + println(node) } else { - println(expr.sql) + println(expr.sql) // TODO (need to add this to the nodes)! 
} // scalastyle:on println } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index c083ee89db6f2..231d361810f84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -468,7 +468,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val branches = replacementMap.flatMap { case (source, target) => Seq(Literal(source), buildExpr(target)) }.toSeq - new Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) + Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) } private def convertToDouble(v: Any): Double = v match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 6b3b374ae9ad9..d059f5ada576b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode} +import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.errors.QueryExecutionErrors @@ -205,7 +205,7 @@ object StatFunctions extends Logging { val column = col(field.name) var casted = column if (field.dataType.isInstanceOf[StringType]) { - casted = new Column(Cast(column.expr, DoubleType, evalMode = EvalMode.TRY)) + casted = column.try_cast(DoubleType) } val percentilesCol = if (percentiles.nonEmpty) { @@ -252,7 +252,7 @@ object StatFunctions extends Logging { .withColumnRenamed("_1", "summary") } else { val valueColumns = columnNames.map { columnName => - new Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName) + Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName) } import org.apache.spark.util.ArrayImplicits._ ds.select(mapColumns: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 32aa13a29cec3..7b4ef3ba9ecd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -215,6 +215,6 @@ class WindowSpec private[sql]( */ private[sql] def withAggregate(aggregate: Column): Column = { val spec = WindowSpecDefinition(partitionSpec, orderSpec, frame) - new Column(WindowExpression(aggregate.expr, spec)) + Column(WindowExpression(aggregate.expr, spec)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f0667ba94a4ec..89a873deb7985 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,12 +26,8 @@ import scala.util.Try import org.apache.spark.annotation.Stable import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -89,10 +85,6 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private def withExpr(expr: => Expression): Column = withOrigin { - Column(expr) - } - /** * Returns a [[Column]] based on the given column name. * @@ -129,7 +121,7 @@ object functions { // method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, // we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. // This is significantly better when there are many threads calling `lit` concurrently. - Column(Literal(literal)) + Column(internal.Literal(literal)) } } @@ -141,7 +133,7 @@ object functions { * @group normal_funcs * @since 2.2.0 */ - def typedLit[T : TypeTag](literal: T): Column = withOrigin { + def typedLit[T : TypeTag](literal: T): Column = { typedlit(literal) } @@ -160,11 +152,13 @@ object functions { * @group normal_funcs * @since 3.2.0 */ - def typedlit[T : TypeTag](literal: T): Column = withOrigin { + def typedlit[T : TypeTag](literal: T): Column = { literal match { case c: Column => c case s: Symbol => new ColumnName(s.name) - case _ => Column(Literal.create(literal)) + case _ => + val dataType = ScalaReflection.schemaFor[T].dataType + Column(internal.Literal(literal, Option(dataType))) } } @@ -410,6 +404,8 @@ object functions { corr(Column(columnName1), Column(columnName2)) } + private val ONE = Column(internal.Literal(1, Option(IntegerType))) + /** * Aggregate function: returns the number of items in a group. * @@ -417,9 +413,9 @@ object functions { * @since 1.3.0 */ def count(e: Column): Column = { - val withoutStar = e.expr match { + val withoutStar = e.node match { // Turn count(*) into count(1) - case _: Star => Column(Literal(1)) + case internal.UnresolvedStar(None, _, _) => ONE case _ => e } Column.fn("count", withoutStar) @@ -1714,10 +1710,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def broadcast[T](df: Dataset[T]): Dataset[T] = { - Dataset[T](df.sparkSession, - ResolvedHint(df.logicalPlan, HintInfo(strategy = Some(BROADCAST))))(df.exprEnc) - } + def broadcast[T](df: Dataset[T]): Dataset[T] = df.hint("broadcast") /** * Returns the first column that is not null, or null if all inputs are null. @@ -2013,9 +2006,8 @@ object functions { * @group conditional_funcs * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = withExpr { - CaseWhen(Seq((condition.expr, lit(value).expr))) - } + def when(condition: Column, value: Any): Column = + Column(internal.CaseWhenOtherwise(Seq(condition.node -> lit(value).node))) /** * Computes bitwise NOT (~) of a number. 
@@ -2073,12 +2065,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = withExpr { - val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser() - } - parser.parseExpression(expr) - } + def expr(expr: String): Column = Column(internal.SqlExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions @@ -6066,31 +6053,26 @@ object functions { def array_except(col1: Column, col2: Column): Column = Column.fn("array_except", col1, col2) - private def createLambda(f: Column => Column) = withOrigin { - Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val function = f(Column(x)).expr - LambdaFunction(function, Seq(x)) - } + + private def createLambda(f: Column => Column) = { + val x = internal.UnresolvedNamedLambdaVariable("x") + val function = f(Column(x)).node + Column(internal.LambdaFunction(function, Seq(x))) } - private def createLambda(f: (Column, Column) => Column) = withOrigin { - Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val function = f(Column(x), Column(y)).expr - LambdaFunction(function, Seq(x, y)) - } + private def createLambda(f: (Column, Column) => Column) = { + val x = internal.UnresolvedNamedLambdaVariable("x") + val y = internal.UnresolvedNamedLambdaVariable("y") + val function = f(Column(x), Column(y)).node + Column(internal.LambdaFunction(function, Seq(x, y))) } - private def createLambda(f: (Column, Column, Column) => Column) = withOrigin { - Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) - val function = f(Column(x), Column(y), Column(z)).expr - LambdaFunction(function, Seq(x, y, z)) - } + private def createLambda(f: (Column, Column, Column) => Column) = { + val x = internal.UnresolvedNamedLambdaVariable("x") + val y = internal.UnresolvedNamedLambdaVariable("y") + val z = internal.UnresolvedNamedLambdaVariable("z") + val function = f(Column(x), Column(y), Column(z)).node + Column(internal.LambdaFunction(function, Seq(x, y, z))) } /** @@ -8396,7 +8378,7 @@ object functions { */ @scala.annotation.varargs @deprecated("Use call_udf") - def callUDF(udfName: String, cols: Column*): Column = call_udf(udfName, cols: _*) + def callUDF(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*) /** * Call an user-defined function. @@ -8414,7 +8396,7 @@ object functions { * @since 3.2.0 */ @scala.annotation.varargs - def call_udf(udfName: String, cols: Column*): Column = Column.fn(udfName, cols: _*) + def call_udf(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*) /** * Call a SQL function. 
@@ -8427,15 +8409,7 @@ object functions { */ @scala.annotation.varargs def call_function(funcName: String, cols: Column*): Column = { - val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser() - } - val nameParts = parser.parseMultipartIdentifier(funcName) - call_function(nameParts, cols: _*) - } - - private def call_function(nameParts: Seq[String], cols: Column*): Column = withExpr { - UnresolvedFunction(nameParts, cols.map(_.expr), false) + Column(internal.UnresolvedFunction(funcName, cols.map(_.node), isUserDefinedFunction = true)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index d982a000ad374..48ac2cc5d4044 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -82,8 +82,8 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession { // items: Seq[Int] => items.map { item => Seq(Struct(item)) } val result = df.select( - new Column(MapObjects( - (item: Expression) => array(struct(new Column(item))).expr, + Column(MapObjects( + (item: Expression) => array(struct(Column(item))).expr, $"items".expr, df.schema("items").dataType.asInstanceOf[ArrayType].elementType )) as "items" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 155acc98cb33b..301ab28b9124b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation} @@ -1566,7 +1566,7 @@ class DataFrameSuite extends QueryTest test("SPARK-46794: exclude subqueries from LogicalRDD constraints") { withTempDir { checkpointDir => val subquery = - new Column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan)) + Column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan)) val df = spark.range(1000).filter($"id" === subquery) assert(df.logicalPlan.constraints.exists(_.exists(_.isInstanceOf[ScalarSubquery]))) @@ -1839,7 +1839,7 @@ class DataFrameSuite extends QueryTest } test("Uuid expressions should produce same results at retries in the same DataFrame") { - val df = spark.range(1).select($"id", new Column(Uuid())) + val df = spark.range(1).select($"id", uuid()) checkAnswer(df, df.collect()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 
061b353879d14..dc0e30f7ccf00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -36,9 +36,8 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkException, SparkUnsupportedOperationException, TestUtils} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Row, SaveMode} +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LocalRelation} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes @@ -1002,7 +1001,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } val stream = MemoryStream[Int] - val df = stream.toDF().select(new Column(Uuid())) + val df = stream.toDF().select(uuid()) testStream(df)( AddData(stream, 1), CheckAnswer(collectUuid), @@ -1022,7 +1021,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } val stream = MemoryStream[Int] - val df = stream.toDF().select(new Column(new Rand()), new Column(new Randn())) + val df = stream.toDF().select(rand(), randn()) testStream(df)( AddData(stream, 1), CheckAnswer(collectRand), @@ -1041,7 +1040,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } val stream = MemoryStream[Int] - val df = stream.toDF().select(new Column(new Shuffle(Literal.create[Seq[Int]](0 until 100)))) + val df = stream.toDF().select(shuffle(typedLit[Seq[Int]](0 until 100))) testStream(df)( AddData(stream, 1), CheckAnswer(collectShuffle), From e7a2a327b36484d0d02116b55ddf5a60dd13141c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 6 Aug 2024 22:23:17 -0400 Subject: [PATCH 02/18] Add internally registered functions --- .../scala/org/apache/spark/sql/internal/columnNodes.scala | 1 + .../sql/internal/ColumnNodeToExpressionConverter.scala | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index 19ce76bbc1ac7..bd827213d20a7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -88,6 +88,7 @@ private[sql] case class UnresolvedFunction( arguments: Seq[ColumnNode], isDistinct: Boolean = false, isUserDefinedFunction: Boolean = false, + isInternal: Boolean = false, override val origin: Origin = CurrentOrigin.get) extends ColumnNode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala index 8e3978ac16857..029bea398e089 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala @@ -66,13 +66,17 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case UnresolvedRegex(unparsedIdentifier, planId, _) => 
convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false) - case UnresolvedFunction(functionName, arguments, isDistinct, isUDF, _) => + case UnresolvedFunction(functionName, arguments, isDistinct, isUDF, isInternal, _) => val nameParts = if (isUDF) { parser.parseMultipartIdentifier(functionName) } else { Seq(functionName) } - analysis.UnresolvedFunction(nameParts, arguments.map(apply), isDistinct) + analysis.UnresolvedFunction( + nameParts = nameParts, + arguments = arguments.map(apply), + isDistinct = isDistinct, + isInternal = isInternal) case Alias(child, Seq(name), metadata, _) => expressions.Alias(apply(child), name)(explicitMetadata = metadata) From a4e52f49d1f39d0c5b51b14dbd9e072baa6aa0a8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 6 Aug 2024 22:53:03 -0400 Subject: [PATCH 03/18] Move window to cool new API :) --- .../spark/sql/internal/columnNodes.scala | 4 +- .../apache/spark/sql/expressions/Window.scala | 2 +- .../spark/sql/expressions/WindowSpec.scala | 47 +++++++++---------- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index bd827213d20a7..0859ec74262b8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} -import org.apache.spark.sql.types.{DataType, Metadata} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, Metadata} /** * AST for constructing columns. This API is implementation agnostic and allows us to build a @@ -203,6 +203,8 @@ private[sql] object WindowFrame { object CurrentRow extends FrameBoundary object Unbounded extends FrameBoundary case class Value(value: ColumnNode) extends FrameBoundary + def value(i: Int): Value = Value(Literal(i, Some(IntegerType))) + def value(l: Long): Value = Value(Literal(l, Some(LongType))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 93bf738a53daf..2bf8a36d511fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -215,7 +215,7 @@ object Window { } private[sql] def spec: WindowSpec = { - new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) + new WindowSpec(Seq.empty, Seq.empty, None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 7b4ef3ba9ecd6..ababf1ac15427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.{ColumnNode, SortOrder, Window => EvalWindow, WindowFrame, WindowSpec} /** * A window specification that defines the partitioning, ordering, and frame boundaries. 
@@ -31,9 +31,9 @@ import org.apache.spark.sql.errors.QueryCompilationErrors */ @Stable class WindowSpec private[sql]( - partitionSpec: Seq[Expression], + partitionSpec: Seq[ColumnNode], orderSpec: Seq[SortOrder], - frame: WindowFrame) { + frame: Option[WindowFrame]) { /** * Defines the partitioning columns in a [[WindowSpec]]. @@ -50,7 +50,7 @@ class WindowSpec private[sql]( */ @scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { - new WindowSpec(cols.map(_.expr), orderSpec, frame) + new WindowSpec(cols.map(_.node), orderSpec, frame) } /** @@ -69,11 +69,9 @@ class WindowSpec private[sql]( @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { val sortOrder: Seq[SortOrder] = cols.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) + col.node match { + case s: SortOrder => s + case _ => col.asc.node.asInstanceOf[SortOrder] } } new WindowSpec(partitionSpec, sortOrder, frame) @@ -125,23 +123,23 @@ class WindowSpec private[sql]( // Note: when updating the doc for this method, also update Window.rowsBetween. def rowsBetween(start: Long, end: Long): WindowSpec = { val boundaryStart = start match { - case 0 => CurrentRow - case Long.MinValue => UnboundedPreceding - case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case 0 => WindowFrame.CurrentRow + case Long.MinValue => WindowFrame.Unbounded + case x if Int.MinValue <= x && x <= Int.MaxValue => WindowFrame.value(x.toInt) case x => throw QueryCompilationErrors.invalidBoundaryStartError(x) } val boundaryEnd = end match { - case 0 => CurrentRow - case Long.MaxValue => UnboundedFollowing - case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case 0 => WindowFrame.CurrentRow + case Long.MaxValue => WindowFrame.Unbounded + case x if Int.MinValue <= x && x <= Int.MaxValue => WindowFrame.value(x.toInt) case x => throw QueryCompilationErrors.invalidBoundaryEndError(x) } new WindowSpec( partitionSpec, orderSpec, - SpecifiedWindowFrame(RowFrame, boundaryStart, boundaryEnd)) + Some(WindowFrame(WindowFrame.FrameType.Row, boundaryStart, boundaryEnd))) } /** @@ -193,28 +191,27 @@ class WindowSpec private[sql]( // Note: when updating the doc for this method, also update Window.rangeBetween. def rangeBetween(start: Long, end: Long): WindowSpec = { val boundaryStart = start match { - case 0 => CurrentRow - case Long.MinValue => UnboundedPreceding - case x => Literal(x) + case 0 => WindowFrame.CurrentRow + case Long.MinValue => WindowFrame.Unbounded + case x => WindowFrame.value(x) } val boundaryEnd = end match { - case 0 => CurrentRow - case Long.MaxValue => UnboundedFollowing - case x => Literal(x) + case 0 => WindowFrame.CurrentRow + case Long.MaxValue => WindowFrame.Unbounded + case x => WindowFrame.value(x) } new WindowSpec( partitionSpec, orderSpec, - SpecifiedWindowFrame(RangeFrame, boundaryStart, boundaryEnd)) + Some(WindowFrame(WindowFrame.FrameType.Range, boundaryStart, boundaryEnd))) } /** * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. 
*/ private[sql] def withAggregate(aggregate: Column): Column = { - val spec = WindowSpecDefinition(partitionSpec, orderSpec, frame) - Column(WindowExpression(aggregate.expr, spec)) + Column(EvalWindow(aggregate.node, WindowSpec(partitionSpec, orderSpec, frame))) } } From dcde4d467560edac7d0ca5d1a8e2bda32c5ca1b3 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 9 Aug 2024 09:43:09 -0400 Subject: [PATCH 04/18] Improve Window --- .../scala/org/apache/spark/sql/Column.scala | 5 +++ .../apache/spark/sql/expressions/Window.scala | 1 - .../spark/sql/expressions/WindowSpec.scala | 34 ++++++++---------- .../sql/DataFrameWindowFramesSuite.scala | 35 +++++++++---------- 4 files changed, 37 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d5bbb57786486..bb178a4a5d77f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1254,6 +1254,11 @@ class Column(val node: ColumnNode) extends Logging { Column(internal.SortOrder(node, sortDirection, nullOrdering)) } + private[sql] def sortOrder: internal.SortOrder = node match { + case order: internal.SortOrder => order + case _ => asc.node.asInstanceOf[internal.SortOrder] + } + /** * Returns a sort expression based on the descending order of the column. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 2bf8a36d511fe..9c4499ee243f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{WindowSpec => _, _} /** * Utility functions for defining window in DataFrames. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index ababf1ac15427..86318fa1704fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.{ColumnNode, SortOrder, Window => EvalWindow, WindowFrame, WindowSpec} +import org.apache.spark.sql.internal.{ColumnNode, SortOrder, Window => EvalWindow, WindowFrame, WindowSpec => InternalWindowSpec} /** * A window specification that defines the partitioning, ordering, and frame boundaries. 
@@ -68,13 +68,7 @@ class WindowSpec private[sql]( */ @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { - val sortOrder: Seq[SortOrder] = cols.map { col => - col.node match { - case s: SortOrder => s - case _ => col.asc.node.asInstanceOf[SortOrder] - } - } - new WindowSpec(partitionSpec, sortOrder, frame) + new WindowSpec(cols.map(_.node), cols.map(_.sortOrder), frame) } /** @@ -136,10 +130,7 @@ class WindowSpec private[sql]( case x => throw QueryCompilationErrors.invalidBoundaryEndError(x) } - new WindowSpec( - partitionSpec, - orderSpec, - Some(WindowFrame(WindowFrame.FrameType.Row, boundaryStart, boundaryEnd))) + withFrame(WindowFrame.Row, boundaryStart, boundaryEnd) } /** @@ -192,26 +183,31 @@ class WindowSpec private[sql]( def rangeBetween(start: Long, end: Long): WindowSpec = { val boundaryStart = start match { case 0 => WindowFrame.CurrentRow - case Long.MinValue => WindowFrame.Unbounded + case Long.MinValue => WindowFrame.UnboundedPreceding case x => WindowFrame.value(x) } val boundaryEnd = end match { case 0 => WindowFrame.CurrentRow - case Long.MaxValue => WindowFrame.Unbounded + case Long.MaxValue => WindowFrame.UnboundedFollowing case x => WindowFrame.value(x) } + withFrame(WindowFrame.Range, boundaryStart, boundaryEnd) + } - new WindowSpec( - partitionSpec, - orderSpec, - Some(WindowFrame(WindowFrame.FrameType.Range, boundaryStart, boundaryEnd))) + private[sql] def withFrame( + frameType: WindowFrame.FrameType, + lower: WindowFrame.FrameBoundary, + uppper: WindowFrame.FrameBoundary): WindowSpec = { + val frame = WindowFrame(frameType, lower, uppper) + new WindowSpec(partitionSpec, orderSpec, Some(frame)) } /** * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. */ private[sql] def withAggregate(aggregate: Column): Column = { - Column(EvalWindow(aggregate.node, WindowSpec(partitionSpec, orderSpec, frame))) + val spec = InternalWindowSpec(partitionSpec, sortColumns = orderSpec, frame = frame) + Column(EvalWindow(aggregate.node, spec)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 95f4cc78d1564..4923ef9c4ebb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, NonFoldableLiteral, RangeFrame, SortOrder, SpecifiedWindowFrame, UnaryMinus, UnspecifiedFrame} +import org.apache.spark.sql.catalyst.expressions.{Literal, NonFoldableLiteral} import org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions import org.apache.spark.sql.catalyst.plans.logical.{Window => WindowNode} -import org.apache.spark.sql.expressions.{Window, WindowSpec} +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SQLConf, Wrapper} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.CalendarIntervalType @@ -503,11 +503,11 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { test("Window frame bounds lower and upper do not have the same type") { val df = Seq((1L, "1"), (1L, "1")).toDF("key", "value") - val windowSpec = new WindowSpec( - Seq(Column("value").expr), - Seq(SortOrder(Column("key").expr, Ascending)), - 
SpecifiedWindowFrame(RangeFrame, Literal.create(null, CalendarIntervalType), Literal(2)) - ) + + val windowSpec = Window.partitionBy($"value").orderBy($"key".asc).withFrame( + internal.WindowFrame.Range, + internal.WindowFrame.Value(Wrapper(Literal.create(null, CalendarIntervalType))), + internal.WindowFrame.Value(lit(2).node)) checkError( exception = intercept[AnalysisException] { df.select($"key", count("key").over(windowSpec)).collect() @@ -526,11 +526,10 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { test("Window frame lower bound is not a literal") { val df = Seq((1L, "1"), (1L, "1")).toDF("key", "value") - val windowSpec = new WindowSpec( - Seq(Column("value").expr), - Seq(SortOrder(Column("key").expr, Ascending)), - SpecifiedWindowFrame(RangeFrame, NonFoldableLiteral(1), Literal(2)) - ) + val windowSpec = Window.partitionBy($"value").orderBy($"key".asc).withFrame( + internal.WindowFrame.Range, + internal.WindowFrame.Value(Wrapper(NonFoldableLiteral(1))), + internal.WindowFrame.Value(lit(2).node)) checkError( exception = intercept[AnalysisException] { df.select($"key", count("key").over(windowSpec)).collect() @@ -546,8 +545,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { test("SPARK-41805: Reuse expressions in WindowSpecDefinition") { val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") - val sortOrder = SortOrder($"n".cast("string").expr, Ascending) - val window = new WindowSpec(Seq($"n".expr), Seq(sortOrder), UnspecifiedFrame) + val window = Window.partitionBy($"n").orderBy($"n".cast("string").asc) val df = ds.select(sum("i").over(window), avg("i").over(window)) val ws = df.queryExecution.analyzed.collect { case w: WindowNode => w } assert(ws.size === 1) @@ -557,9 +555,10 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { test("SPARK-41793: Incorrect result for window frames defined by a range clause on large " + "decimals") { - val window = new WindowSpec(Seq($"a".expr), Seq(SortOrder($"b".expr, Ascending)), - SpecifiedWindowFrame(RangeFrame, - UnaryMinus(Literal(BigDecimal(10.2345))), Literal(BigDecimal(6.7890)))) + val window = Window.partitionBy($"a").orderBy($"b".asc).withFrame( + internal.WindowFrame.Range, + internal.WindowFrame.Value((-lit(BigDecimal(10.2345))).node), + internal.WindowFrame.Value(lit(BigDecimal(10.2345)).node)) val df = Seq( 1 -> "11342371013783243717493546650944543.47", From 9acfdbeed6ad7b31cba774cc0c0b83ccd55ab39b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 9 Aug 2024 09:43:32 -0400 Subject: [PATCH 05/18] Refactor ColumnNode API --- .../spark/sql/internal/columnNodes.scala | 288 ++++++++++++++---- .../scala/org/apache/spark/sql/Column.scala | 51 ++-- .../analysis/DetectAmbiguousSelfJoin.scala | 14 +- .../spark/sql/expressions/WindowSpec.scala | 4 +- ...onverter.scala => columnNodeSupport.scala} | 124 ++++---- ...ColumnNodeToExpressionConverterSuite.scala | 83 ++--- 6 files changed, 366 insertions(+), 198 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/internal/{ColumnNodeToExpressionConverter.scala => columnNodeSupport.scala} (67%) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index 0859ec74262b8..288cb37faa97b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -16,9 +16,11 @@ */ package 
org.apache.spark.sql.internal -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import ColumnNode._ + import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.types.{DataType, IntegerType, LongType, Metadata} +import org.apache.spark.util.SparkClassUtils /** * AST for constructing columns. This API is implementation agnostic and allows us to build a @@ -30,11 +32,59 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType, Metadata} * make constructing nodes easier (e.g. [[CaseWhenOtherwise]]). We could not use the actual connect * protobuf messages because of classpath clashes (e.g. Guava & gRPC) and Maven shading issues. */ -private[sql] trait ColumnNode { +private[sql] trait ColumnNode extends ColumnNodeLike { /** * Origin where the node was created. */ def origin: Origin + + /** + * A normalized version of this node. This is stripped of dataset related (contextual) metadata. + * This is mostly used to power Column.equals and Column.hashcode. + */ + lazy val normalized: ColumnNode = { + val transformed = normalize() + if (this != transformed) { + transformed + } else { + this + } + } + + override private[internal] def normalize(): ColumnNode = this + + /** + * Return a SQL-a-like representation of the node. + * + * This is best effort; there are no guarantees that the returned SQL is valid. + */ + def sql: String +} + +trait ColumnNodeLike { + private[internal] def normalize(): ColumnNodeLike = this + private[internal] def sql: String +} + +private[internal] object ColumnNode { + val NO_ORIGIN: Origin = Origin() + def normalize[T <: ColumnNodeLike](option: Option[T]): Option[T] = + option.map(_.normalize().asInstanceOf[T]) + def normalize[T <: ColumnNodeLike](nodes: Seq[T]): Seq[T] = + nodes.map(_.normalize().asInstanceOf[T]) + def argumentsToSql(nodes: Seq[ColumnNodeLike]): String = + textArgumentsToSql(nodes.map(_.sql)) + def textArgumentsToSql(parts: Seq[String]): String = parts.mkString("(", ",", ")") + def elementsToSql(elements: Seq[ColumnNodeLike], prefix: String = ""): String = { + if (elements.nonEmpty) { + elements.map(_.sql).mkString(prefix, ",", "") + } else { + "" + } + } + def optionToSql(option: Option[ColumnNodeLike]): String = { + option.map(_.sql).getOrElse("") + } } /** @@ -46,7 +96,15 @@ private[sql] trait ColumnNode { private[sql] case class Literal( value: Any, dataType: Option[DataType] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): Literal = copy(origin = NO_ORIGIN) + + // TODO make this nicer. + override def sql: String = value match { + case null => "NULL" + case _ => value.toString + } +} /** * Reference to an attribute produced by one of the underlying DataFrames. @@ -60,7 +118,11 @@ private[sql] case class UnresolvedAttribute( planId: Option[Long] = None, isMetadataColumn: Boolean = false, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode + extends ColumnNode { + override private[internal] def normalize(): UnresolvedAttribute = + copy(planId = None, origin = NO_ORIGIN) + override def sql: String = unparsedIdentifier +} /** * Reference to all columns in a namespace (global, a Dataframe, or a nested struct). 
@@ -72,7 +134,11 @@ private[sql] case class UnresolvedStar( unparsedTarget: Option[String], planId: Option[Long] = None, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode + extends ColumnNode { + override private[internal] def normalize(): UnresolvedStar = + copy(planId = None, origin = NO_ORIGIN) + override def sql: String = unparsedTarget.map(_ + ".*").getOrElse("*") +} /** * Call a function. This can either be a built-in function, a UDF, or a UDF registered in the @@ -90,7 +156,12 @@ private[sql] case class UnresolvedFunction( isUserDefinedFunction: Boolean = false, isInternal: Boolean = false, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode + extends ColumnNode { + override private[internal] def normalize(): UnresolvedFunction = + copy(arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN) + + override def sql: String = functionName + argumentsToSql(arguments) +} /** * Evaluate a SQL expression. @@ -99,7 +170,10 @@ private[sql] case class UnresolvedFunction( */ private[sql] case class SqlExpression( expression: String, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): SqlExpression = copy(origin = NO_ORIGIN) + override def sql: String = expression +} /** * Name a column, and (optionally) modify its metadata. @@ -112,7 +186,18 @@ private[sql] case class Alias( child: ColumnNode, name: Seq[String], metadata: Option[Metadata] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): Alias = + copy(child = child.normalize(), origin = NO_ORIGIN) + + override def sql: String = { + val alias = name match { + case Seq(single) => single + case multiple => textArgumentsToSql(multiple) + } + s"${child.sql} AS $alias" + } +} /** * Cast the value of a Column to a different [[DataType]]. The behavior of the cast can be @@ -125,16 +210,23 @@ private[sql] case class Alias( private[sql] case class Cast( child: ColumnNode, dataType: DataType, - evalMode: Option[Cast.EvalMode.Value] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + evalMode: Option[Cast.EvalMode] = None, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): Cast = + copy(child = child.normalize(), origin = NO_ORIGIN) -private[sql] object Cast { - object EvalMode extends Enumeration { - type EvalMode = Value - val Legacy, Ansi, Try = Value + override def sql: String = { + s"${optionToSql(evalMode)}CAST(${child.sql} AS ${dataType.sql})" } } +private[sql] object Cast { + sealed abstract class EvalMode(override val sql: String = "") extends ColumnNodeLike + object Legacy extends EvalMode + object Ansi extends EvalMode + object Try extends EvalMode("TRY_") +} + /** * Reference to all columns in the global namespace in that match a regex. * @@ -144,7 +236,11 @@ private[sql] object Cast { private[sql] case class UnresolvedRegex( regex: String, planId: Option[Long] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): UnresolvedRegex = + copy(planId = None, origin = NO_ORIGIN) + override def sql: String = regex +} /** * Sort the input column. 
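For example (again a sketch under the same private[sql] assumption, with the usual org.apache.spark.sql.types imports), the sealed Cast.EvalMode objects feed straight into the SQL rendering:

    import org.apache.spark.sql.types.LongType

    val tryCast = Cast(UnresolvedAttribute("str"), LongType, Some(Cast.Try))
    // Cast.Try.sql is "TRY_", so the node renders as a TRY_CAST.
    assert(tryCast.sql == "TRY_CAST(str AS BIGINT)")

    // Legacy and Ansi carry an empty prefix and render as a plain CAST.
    assert(Cast(UnresolvedAttribute("str"), LongType, Some(Cast.Ansi)).sql == "CAST(str AS BIGINT)")

    // Alias composes the same way: the child rendering followed by AS <name>.
    assert(Alias(tryCast, Seq("n")).sql == "TRY_CAST(str AS BIGINT) AS n")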
@@ -155,20 +251,23 @@ private[sql] case class UnresolvedRegex( */ private[sql] case class SortOrder( child: ColumnNode, - sortDirection: SortOrder.SortDirection.Value, - nullOrdering: SortOrder.NullOrdering.Value, + sortDirection: SortOrder.SortDirection, + nullOrdering: SortOrder.NullOrdering, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode + extends ColumnNode { + override private[internal] def normalize(): SortOrder = + copy(child = child.normalize(), origin = NO_ORIGIN) + + override def sql: String = s"${child.sql} ${sortDirection.sql} ${nullOrdering.sql}" +} private[sql] object SortOrder { - object SortDirection extends Enumeration { - type SortDirection = Value - val Ascending, Descending = Value - } - object NullOrdering extends Enumeration { - type NullOrdering = Value - val NullsFirst, NullsLast = Value - } + sealed abstract class SortDirection(override val sql: String) extends ColumnNodeLike + object Ascending extends SortDirection("ASCENDING") + object Descending extends SortDirection("DESCENDING") + sealed abstract class NullOrdering(override val sql: String) extends ColumnNodeLike + object NullsFirst extends NullOrdering("NULLS FIRST") + object NullsLast extends NullOrdering("NULLS LAST") } /** @@ -181,28 +280,64 @@ private[sql] case class Window( windowFunction: ColumnNode, windowSpec: WindowSpec, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode + extends ColumnNode { + override private[internal] def normalize(): Window = copy( + windowFunction = windowFunction.normalize(), + windowSpec = windowSpec.normalize(), + origin = NO_ORIGIN) + + override def sql: String = s"${windowFunction.sql} OVER (${windowSpec.sql})" +} private[sql] case class WindowSpec( partitionColumns: Seq[ColumnNode], sortColumns: Seq[SortOrder], - frame: Option[WindowFrame] = None) + frame: Option[WindowFrame] = None) extends ColumnNodeLike { + override private[internal] def normalize(): WindowSpec = copy( + partitionColumns = ColumnNode.normalize(partitionColumns), + sortColumns = ColumnNode.normalize(sortColumns), + frame = ColumnNode.normalize(frame)) + override private[internal] def sql: String = { + val parts = Seq( + elementsToSql(partitionColumns, "PARTITION BY "), + elementsToSql(sortColumns, "ORDER BY"), + optionToSql(frame)) + parts.filter(_.nonEmpty).mkString(" ") + } +} private[sql] case class WindowFrame( - frameType: WindowFrame.FrameType.Value, + frameType: WindowFrame.FrameType, lower: WindowFrame.FrameBoundary, upper: WindowFrame.FrameBoundary) + extends ColumnNodeLike { + override private[internal] def normalize(): WindowFrame = + copy(lower = lower.normalize(), upper = upper.normalize()) + override private[internal] def sql: String = + s"${frameType.sql} BETWEEN ${lower.sql} AND ${upper.sql}" +} private[sql] object WindowFrame { - object FrameType extends Enumeration { - type FrameType = this.Value - val Row, Range = this.Value - } + sealed abstract class FrameType(override val sql: String) extends ColumnNodeLike + object Row extends FrameType("ROWS") + object Range extends FrameType("RANGE") - sealed trait FrameBoundary - object CurrentRow extends FrameBoundary - object Unbounded extends FrameBoundary - case class Value(value: ColumnNode) extends FrameBoundary + sealed abstract class FrameBoundary extends ColumnNodeLike { + override private[internal] def normalize(): FrameBoundary = this + } + object CurrentRow extends FrameBoundary { + override private[internal] def sql = "CURRENT ROW" + } + object UnboundedPreceding extends FrameBoundary { + 
override private[internal] def sql = "UNBOUNDED PRECEDING" + } + object UnboundedFollowing extends FrameBoundary { + override private[internal] def sql = "UNBOUNDED FOLLOWING" + } + case class Value(value: ColumnNode) extends FrameBoundary { + override private[internal] def normalize(): Value = copy(value.normalize()) + override private[internal] def sql: String = value.sql + } def value(i: Int): Value = Value(Literal(i, Some(IntegerType))) def value(l: Long): Value = Value(Literal(l, Some(LongType))) } @@ -216,7 +351,14 @@ private[sql] object WindowFrame { private[sql] case class LambdaFunction( function: ColumnNode, arguments: Seq[UnresolvedNamedLambdaVariable], - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): LambdaFunction = copy( + function = function.normalize(), + arguments = ColumnNode.normalize(arguments), + origin = NO_ORIGIN) + + override def sql: String = argumentsToSql(arguments) + " -> " + function.sql +} /** * Variable used in a [[LambdaFunction]]. @@ -225,7 +367,12 @@ private[sql] case class LambdaFunction( */ private[sql] case class UnresolvedNamedLambdaVariable( name: String, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): UnresolvedNamedLambdaVariable = + copy(origin = NO_ORIGIN) + + override def sql: String = name +} /** * Extract a value from a complex type. This can be a field from a struct, a value from a map, @@ -238,7 +385,14 @@ private[sql] case class UnresolvedNamedLambdaVariable( private[sql] case class UnresolvedExtractValue( child: ColumnNode, extraction: ColumnNode, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): UnresolvedExtractValue = copy( + child = child.normalize(), + extraction = child.normalize(), + origin = NO_ORIGIN) + + override def sql: String = s"${child.sql}[${extraction.sql}]" +} /** * Update or drop the field of a struct. @@ -251,7 +405,16 @@ private[sql] case class UpdateFields( structExpression: ColumnNode, fieldName: String, valueExpression: Option[ColumnNode] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): UpdateFields = copy( + structExpression = structExpression.normalize(), + valueExpression = ColumnNode.normalize(valueExpression), + origin = NO_ORIGIN) + override def sql: String = valueExpression match { + case Some(value) => s"update_field(${structExpression.sql}, $fieldName, ${value.sql})" + case None => s"drop_field(${structExpression.sql}, $fieldName)" + } +} /** * Evaluate one or more conditional branches. 
The value of the first branch for which the predicate @@ -265,7 +428,19 @@ private[sql] case class CaseWhenOtherwise( branches: Seq[(ColumnNode, ColumnNode)], otherwise: Option[ColumnNode] = None, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode + extends ColumnNode { + assert(branches.nonEmpty) + override private[internal] def normalize(): CaseWhenOtherwise = copy( + branches = branches.map(kv => (kv._1.normalize(), kv._2.normalize())), + otherwise = ColumnNode.normalize(otherwise), + origin = NO_ORIGIN) + + override def sql: String = + "CASE " + + branches.map(cv => s"WHEN ${cv._1.sql} THEN ${cv._2.sql}").mkString(" ") + + otherwise.map(o => s"ELSE ${o.sql}") + + " END" +} /** * Invoke an inline user defined function. @@ -274,25 +449,18 @@ private[sql] case class CaseWhenOtherwise( * @param arguments to pass into the user defined function. */ private[sql] case class InvokeInlineUserDefinedFunction( - function: UserDefinedFunction, + function: UserDefinedFunctionLike, arguments: Seq[ColumnNode], + isDistinct: Boolean = false, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode + extends ColumnNode { + override private[internal] def normalize(): InvokeInlineUserDefinedFunction = + copy(arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN) -// This is a temporary class until we move the actual interfaces -private[sql] case class UserDefinedFunction( - function: AnyRef, - resultEncoder: AgnosticEncoder[Any], - inputEncoders: Seq[AgnosticEncoder[Any]], - name: Option[String], - nonNullable: Boolean, - deterministic: Boolean) + override def sql: String = + function.name + argumentsToSql(arguments) +} -/** - * Extension point that allows an implementation to use its column representation to be used in a - * generic column expression. This should only be used when the Column constructed is used within - * the implementation. - */ -private[sql] case class Extension( - value: Any, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode +private[sql] trait UserDefinedFunctionLike { + def name: String = SparkClassUtils.getFormattedClassName(this) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index bb178a4a5d77f..72f5d0a2ba25f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{ColumnNode, Extension, TypedAggUtils} +import org.apache.spark.sql.internal.{ColumnNode, TypedAggUtils, Wrapper} import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ @@ -39,14 +39,10 @@ private[spark] object Column { def apply(colName: String): Column = new Column(colName) - // TODO move this to a separate class! - // Move as much as we can to the new API - // Create internal util for create expression(nodes) - def apply(expr: Expression): Column = Column(Extension(expr)) + def apply(expr: Expression): Column = Column(Wrapper(expr)) def apply(node: => ColumnNode): Column = withOrigin(new Column(node)) - // TODO move this else where... 
private[sql] def generateAlias(e: Expression): String = { e match { case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => @@ -55,16 +51,6 @@ private[spark] object Column { } } - // TODO move this else where... - private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = { - val metadataWithoutId = new MetadataBuilder() - .withMetadata(a.metadata) - .remove(Dataset.DATASET_ID_KEY) - .remove(Dataset.COL_POS_KEY) - .build() - a.withMetadata(metadataWithoutId) - } - private[sql] def fn(name: String, inputs: Column*): Column = { fn(name, isDistinct = false, inputs: _*) } @@ -163,8 +149,6 @@ class TypedColumn[-T, U]( */ @Stable class Column(val node: ColumnNode) extends Logging { - // TODO this will be moved to the calling classes. - // We must, must, must move all user facing use cases. lazy val expr: Expression = internal.ColumnNodeToExpressionConverter(node) def this(name: String) = this(withOrigin { @@ -185,19 +169,18 @@ class Column(val node: ColumnNode) extends Logging { Column.fn(name, this, lit(other)) } - override def toString: String = toPrettySQL(expr) + override def toString: String = node.sql override def equals(that: Any): Boolean = that match { - case that: Column => that.node == this.node + case that: Column => that.node.normalized == this.node.normalized case _ => false } - override def hashCode: Int = this.node.hashCode() + override def hashCode: Int = this.node.normalized.hashCode() /** * Returns the expression for this column either with an existing or auto assigned name. */ - // TODO move this elsewhere. private[sql] def named: NamedExpression = expr match { case expr: NamedExpression => expr @@ -1231,7 +1214,7 @@ class Column(val node: ColumnNode) extends Logging { * @since 4.0.0 */ def try_cast(to: DataType): Column = { - Column(internal.Cast(node, to, Option(internal.Cast.EvalMode.Try))) + Column(internal.Cast(node, to, Option(internal.Cast.Try))) } /** @@ -1249,8 +1232,8 @@ class Column(val node: ColumnNode) extends Logging { } private def sortOrder( - sortDirection: internal.SortOrder.SortDirection.Value, - nullOrdering: internal.SortOrder.NullOrdering.Value): Column = { + sortDirection: internal.SortOrder.SortDirection, + nullOrdering: internal.SortOrder.NullOrdering): Column = { Column(internal.SortOrder(node, sortDirection, nullOrdering)) } @@ -1289,8 +1272,8 @@ class Column(val node: ColumnNode) extends Logging { * @since 2.1.0 */ def desc_nulls_first: Column = sortOrder( - internal.SortOrder.SortDirection.Descending, - internal.SortOrder.NullOrdering.NullsFirst) + internal.SortOrder.Descending, + internal.SortOrder.NullsFirst) /** * Returns a sort expression based on the descending order of the column, @@ -1307,8 +1290,8 @@ class Column(val node: ColumnNode) extends Logging { * @since 2.1.0 */ def desc_nulls_last: Column = sortOrder( - internal.SortOrder.SortDirection.Descending, - internal.SortOrder.NullOrdering.NullsLast) + internal.SortOrder.Descending, + internal.SortOrder.NullsLast) /** * Returns a sort expression based on ascending order of the column. 
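A short usage sketch of what the node-based toString and the sealed sort/cast helpers above should produce (assuming import org.apache.spark.sql.functions.col and org.apache.spark.sql.types.LongType; the exact strings follow from the sql implementations in columnNodes.scala):

    // col("price") builds an internal.UnresolvedAttribute node, and
    // desc_nulls_first wraps it in internal.SortOrder(Descending, NullsFirst),
    // so toString is rendered from the node instead of a Catalyst expression.
    assert(col("price").desc_nulls_first.toString == "price DESCENDING NULLS FIRST")

    // try_cast maps onto internal.Cast.Try, which prefixes the rendering.
    assert(col("price").try_cast(LongType).toString == "TRY_CAST(price AS BIGINT)")

    // equals/hashCode go through node.normalized, so equivalent columns built
    // at different call sites (different origins) still compare equal.
    assert(col("price").desc_nulls_first == col("price").desc_nulls_first)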
@@ -1340,8 +1323,8 @@ class Column(val node: ColumnNode) extends Logging { * @since 2.1.0 */ def asc_nulls_first: Column = sortOrder( - internal.SortOrder.SortDirection.Ascending, - internal.SortOrder.NullOrdering.NullsFirst) + internal.SortOrder.Ascending, + internal.SortOrder.NullsFirst) /** * Returns a sort expression based on ascending order of the column, @@ -1358,8 +1341,8 @@ class Column(val node: ColumnNode) extends Logging { * @since 2.1.0 */ def asc_nulls_last: Column = sortOrder( - internal.SortOrder.SortDirection.Ascending, - internal.SortOrder.NullOrdering.NullsLast) + internal.SortOrder.Ascending, + internal.SortOrder.NullsLast) /** * Prints the expression to the console for debugging purposes. @@ -1372,7 +1355,7 @@ class Column(val node: ColumnNode) extends Logging { if (extended) { println(node) } else { - println(expr.sql) // TODO (need to add this to the nodes)! + println(node.sql) } // scalastyle:on println } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala index c2925a3ba596b..25c8f695689c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.analysis import scala.collection.mutable import org.apache.spark.SparkException -import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Cast, Equality, Expression, ExprId} import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.MetadataBuilder /** * Detects ambiguous self-joins, so that we can fail the query instead of returning confusing @@ -169,7 +170,16 @@ object DetectAmbiguousSelfJoin extends Rule[LogicalPlan] { plan.transformExpressions { case a: AttributeReference if isColumnReference(a) => // Remove the special metadata from this `AttributeReference`, as the detection is done. 
- Column.stripColumnReferenceMetadata(a) + stripColumnReferenceMetadata(a) } } + + private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = { + val metadataWithoutId = new MetadataBuilder() + .withMetadata(a.metadata) + .remove(Dataset.DATASET_ID_KEY) + .remove(Dataset.COL_POS_KEY) + .build() + a.withMetadata(metadataWithoutId) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 86318fa1704fd..c8d7f538df43c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -118,14 +118,14 @@ class WindowSpec private[sql]( def rowsBetween(start: Long, end: Long): WindowSpec = { val boundaryStart = start match { case 0 => WindowFrame.CurrentRow - case Long.MinValue => WindowFrame.Unbounded + case Long.MinValue => WindowFrame.UnboundedPreceding case x if Int.MinValue <= x && x <= Int.MaxValue => WindowFrame.value(x.toInt) case x => throw QueryCompilationErrors.invalidBoundaryStartError(x) } val boundaryEnd = end match { case 0 => WindowFrame.CurrentRow - case Long.MaxValue => WindowFrame.Unbounded + case Long.MaxValue => WindowFrame.UnboundedFollowing case x if Int.MinValue <= x && x <= Int.MaxValue => WindowFrame.value(x.toInt) case x => throw QueryCompilationErrors.invalidBoundaryEndError(x) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala similarity index 67% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 029bea398e089..93f89679ddac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundEncoder -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression} -import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF, TypedAggregateExpression} +import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin +import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator} +import org.apache.spark.sql.expressions.UserDefinedFunctionUtils.toScalaUDF /** * Convert a 
[[ColumnNode]] into an [[Expression]]. @@ -86,9 +86,9 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case Cast(child, dataType, evalMode, _) => val convertedEvalMode = evalMode match { - case Some(Cast.EvalMode.Ansi) => expressions.EvalMode.ANSI - case Some(Cast.EvalMode.Legacy) => expressions.EvalMode.LEGACY - case Some(Cast.EvalMode.Try) => expressions.EvalMode.TRY + case Some(Cast.Ansi) => expressions.EvalMode.ANSI + case Some(Cast.Legacy) => expressions.EvalMode.LEGACY + case Some(Cast.Try) => expressions.EvalMode.TRY case _ => expressions.EvalMode.fromSQLConf(conf) } val cast = expressions.Cast( @@ -109,20 +109,13 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres val frame = spec.frame match { case Some(WindowFrame(frameType, lower, upper)) => val convertedFrameType = frameType match { - case WindowFrame.FrameType.Range => expressions.RangeFrame - case WindowFrame.FrameType.Row => expressions.RowFrame + case WindowFrame.Range => expressions.RangeFrame + case WindowFrame.Row => expressions.RowFrame } - val convertedLower = lower match { - case WindowFrame.CurrentRow => expressions.CurrentRow - case WindowFrame.Unbounded => expressions.UnboundedPreceding - case WindowFrame.Value(node) => apply(node) - } - val convertedUpper = upper match { - case WindowFrame.CurrentRow => expressions.CurrentRow - case WindowFrame.Unbounded => expressions.UnboundedFollowing - case WindowFrame.Value(node) => apply(node) - } - expressions.SpecifiedWindowFrame(convertedFrameType, convertedLower, convertedUpper) + expressions.SpecifiedWindowFrame( + convertedFrameType, + convertWindowFrameBoundary(lower), + convertWindowFrameBoundary(upper)) case None => expressions.UnspecifiedFrame } @@ -157,40 +150,24 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres }, elseValue = otherwise.map(apply)) - case InvokeInlineUserDefinedFunction(f, arguments, _) => - // This code is a bit clunky, it will stay this way until we have moved everything to - // sql/api and we can actually use the SparkUserDefinedFunction and - // UserDefinedAggregator classes. 
- (f.function, arguments.map(apply)) match { - case (a: Aggregator[Any @unchecked, Any @unchecked, Any @unchecked], Nil) => - TypedAggregateExpression(a)(a.bufferEncoder, a.outputEncoder).toAggregateExpression() - - case (a: Aggregator[Any @unchecked, Any @unchecked, Any @unchecked], children) => - ScalaAggregator( - agg = a, - children = children, - inputEncoder = ExpressionEncoder(f.inputEncoders.head), - bufferEncoder = ExpressionEncoder(f.resultEncoder), - aggregatorName = f.name, - nullable = !f.nonNullable && f.resultEncoder.nullable, - isDeterministic = f.deterministic).toAggregateExpression() - - case (function, children) => - ScalaUDF( - function = function, - dataType = f.resultEncoder.dataType, - children = children, - inputEncoders = f.inputEncoders.map { - case UnboundEncoder => None - case encoder => Option(ExpressionEncoder(encoder)) - }, - outputEncoder = Option(ExpressionEncoder(f.resultEncoder)), - udfName = f.name, - nullable = !f.nonNullable && f.resultEncoder.nullable, - udfDeterministic = f.deterministic) - } + case InvokeInlineUserDefinedFunction( + a: Aggregator[Any @unchecked, Any @unchecked, Any @unchecked], Nil, isDistinct, _) => + TypedAggregateExpression(a)(a.bufferEncoder, a.outputEncoder) + .toAggregateExpression(isDistinct) + + case InvokeInlineUserDefinedFunction( + a: UserDefinedAggregator[Any @unchecked, Any @unchecked, Any @unchecked], + arguments, isDistinct, _) => + ScalaAggregator(a, arguments.map(apply)).toAggregateExpression(isDistinct) + + case InvokeInlineUserDefinedFunction( + a: UserDefinedAggregateFunction, arguments, isDistinct, _) => + ScalaUDAF(udaf = a, children = arguments.map(apply)).toAggregateExpression(isDistinct) - case Extension(expression: Expression, _) => + case InvokeInlineUserDefinedFunction(udf: SparkUserDefinedFunction, arguments, _, _) => + toScalaUDF(udf, arguments.map(apply)) + + case Wrapper(expression, _) => expression case node => @@ -205,16 +182,25 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres private def convertSortOrder(sortOrder: SortOrder): expressions.SortOrder = { val sortDirection = sortOrder.sortDirection match { - case SortOrder.SortDirection.Ascending => expressions.Ascending - case SortOrder.SortDirection.Descending => expressions.Descending + case SortOrder.Ascending => expressions.Ascending + case SortOrder.Descending => expressions.Descending } val nullOrdering = sortOrder.nullOrdering match { - case SortOrder.NullOrdering.NullsFirst => expressions.NullsFirst - case SortOrder.NullOrdering.NullsLast => expressions.NullsLast + case SortOrder.NullsFirst => expressions.NullsFirst + case SortOrder.NullsLast => expressions.NullsLast } expressions.SortOrder(apply(sortOrder.child), sortDirection, nullOrdering, Nil) } + private def convertWindowFrameBoundary(boundary: WindowFrame.FrameBoundary): Expression = { + boundary match { + case WindowFrame.CurrentRow => expressions.CurrentRow + case WindowFrame.UnboundedPreceding => expressions.UnboundedPreceding + case WindowFrame.UnboundedFollowing => expressions.UnboundedFollowing + case WindowFrame.Value(node) => apply(node) + } + } + private def convertUnresolvedAttribute( unparsedIdentifier: String, planId: Option[Long], @@ -230,7 +216,7 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres } } -object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { +private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { override protected def parser: 
ParserInterface = { SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { new SparkSqlParser() @@ -239,3 +225,21 @@ object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { override protected def conf: SQLConf = SQLConf.get } + + +/** + * [[ColumnNode]] wrapper for an [[Expression]]. + */ +private[sql] case class Wrapper( + expression: Expression, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override def normalize(): Wrapper = { + val updated = expression.transform { + case a: AttributeReference => + DetectAmbiguousSelfJoin.stripColumnReferenceMetadata(a) + } + copy(updated, ColumnNode.NO_ORIGIN) + } + + override def sql: String = expression.sql +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala index 708b8b24b4df0..50cb7f437adb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator} import org.apache.spark.sql.types._ /** @@ -131,7 +131,7 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { private def testCast( dataType: DataType, - colEvalMode: Cast.EvalMode.Value, + colEvalMode: Cast.EvalMode, catEvalMode: expressions.EvalMode.Value): Unit = { testConversion( Cast(UnresolvedAttribute("attr"), dataType, Option(colEvalMode)), @@ -143,14 +143,14 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { Cast(UnresolvedAttribute("str"), DoubleType), expressions.Cast(analysis.UnresolvedAttribute("str"), DoubleType)) - testCast(LongType, Cast.EvalMode.Legacy, expressions.EvalMode.LEGACY) - testCast(BinaryType, Cast.EvalMode.Try, expressions.EvalMode.TRY) - testCast(ShortType, Cast.EvalMode.Ansi, expressions.EvalMode.ANSI) + testCast(LongType, Cast.Legacy, expressions.EvalMode.LEGACY) + testCast(BinaryType, Cast.Try, expressions.EvalMode.TRY) + testCast(ShortType, Cast.Ansi, expressions.EvalMode.ANSI) } private def testSortOrder( - colDirection: SortOrder.SortDirection.SortDirection, - colNullOrdering: SortOrder.NullOrdering.NullOrdering, + colDirection: SortOrder.SortDirection, + colNullOrdering: SortOrder.NullOrdering, catDirection: expressions.SortDirection, catNullOrdering: expressions.NullOrdering): Unit = { testConversion( @@ -164,29 +164,29 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { test("sortOrder") { testSortOrder( - SortOrder.SortDirection.Ascending, - SortOrder.NullOrdering.NullsFirst, + SortOrder.Ascending, + SortOrder.NullsFirst, expressions.Ascending, expressions.NullsFirst) testSortOrder( - SortOrder.SortDirection.Ascending, - SortOrder.NullOrdering.NullsLast, + SortOrder.Ascending, + SortOrder.NullsLast, expressions.Ascending, expressions.NullsLast) testSortOrder( - SortOrder.SortDirection.Descending, - SortOrder.NullOrdering.NullsFirst, + SortOrder.Descending, + SortOrder.NullsFirst, expressions.Descending, expressions.NullsFirst) 
testSortOrder( - SortOrder.SortDirection.Descending, - SortOrder.NullOrdering.NullsLast, + SortOrder.Descending, + SortOrder.NullsLast, expressions.Descending, expressions.NullsLast) } private def testWindowFrame( - colFrameType: WindowFrame.FrameType.FrameType, + colFrameType: WindowFrame.FrameType, colLower: WindowFrame.FrameBoundary, colUpper: WindowFrame.FrameBoundary, catFrameType: expressions.FrameType, @@ -199,8 +199,8 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), Seq(SortOrder( UnresolvedAttribute("d"), - SortOrder.SortDirection.Descending, - SortOrder.NullOrdering.NullsLast)), + SortOrder.Descending, + SortOrder.NullsLast)), Option(WindowFrame(colFrameType, colLower, colUpper)))), expressions.WindowExpression( analysis.UnresolvedFunction( @@ -235,15 +235,15 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { Nil, expressions.UnspecifiedFrame))) testWindowFrame( - WindowFrame.FrameType.Row, + WindowFrame.Row, WindowFrame.Value(Literal(-10)), - WindowFrame.Unbounded, + WindowFrame.UnboundedFollowing, expressions.RowFrame, expressions.Literal(-10), expressions.UnboundedFollowing) testWindowFrame( - WindowFrame.FrameType.Range, - WindowFrame.Unbounded, + WindowFrame.Range, + WindowFrame.UnboundedPreceding, WindowFrame.CurrentRow, expressions.RangeFrame, expressions.UnboundedPreceding, @@ -301,14 +301,6 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { test("udf") { val int2LongSum = new aggregate.TypedSumLong[Int]((i: Int) => i.toLong) - val aggregator = UserDefinedFunction( - int2LongSum, - toAny(AgnosticEncoders.PrimitiveLongEncoder), - toAny(AgnosticEncoders.PrimitiveIntEncoder) :: Nil, - name = Option("int2LongSum"), - nonNullable = false, - deterministic = true) - val bufferEncoder = encoderFor(int2LongSum.bufferEncoder) val outputEncoder = encoderFor(int2LongSum.outputEncoder) val bufferAttrs = bufferEncoder.namedExpressions.map { @@ -317,7 +309,7 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { // Aggregator applied on the entire Dataset. testConversion( - InvokeInlineUserDefinedFunction(aggregator, Nil), + InvokeInlineUserDefinedFunction(int2LongSum, Nil), aggregate.SimpleTypedAggregateExpression( aggregator = int2LongSum.asInstanceOf[Aggregator[Any, Any, Any]], inputDeserializer = None, @@ -336,7 +328,12 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { // Aggregator applied on an input. 
testConversion( - InvokeInlineUserDefinedFunction(aggregator, UnresolvedAttribute("i_col") :: Nil), + InvokeInlineUserDefinedFunction( + UserDefinedAggregator( + aggregator = int2LongSum, + inputEncoder = AgnosticEncoders.PrimitiveIntEncoder, + nullable = false), + UnresolvedAttribute("i_col") :: Nil), aggregate.ScalaAggregator( children = analysis.UnresolvedAttribute("i_col") :: Nil, agg = int2LongSum, @@ -349,12 +346,12 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { val concat = (a: String, b: String) => a + b testConversion( InvokeInlineUserDefinedFunction( - UserDefinedFunction( - function = concat, - resultEncoder = toAny(AgnosticEncoders.StringEncoder), - AgnosticEncoders.UnboundEncoder :: toAny(AgnosticEncoders.StringEncoder) :: Nil, - name = None, - nonNullable = true, + SparkUserDefinedFunction( + f = concat, + inputEncoders = None :: Option(toAny(AgnosticEncoders.StringEncoder)) :: Nil, + outputEncoder = Option(toAny(AgnosticEncoders.StringEncoder)), + dataType = StringType, + nullable = false, deterministic = false), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"))), expressions.ScalaUDF( @@ -370,11 +367,17 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { test("extension") { testConversion( - Extension(analysis.UnresolvedAttribute("bar")), + Wrapper(analysis.UnresolvedAttribute("bar")), analysis.UnresolvedAttribute("bar")) } test("unsupported") { - intercept[SparkException](Converter(Extension("kaboom"))) + intercept[SparkException](Converter(Nope())) } } + +private[internal] case class Nope(override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { + override private[internal] def normalize(): Nope = this + override def sql: String = "nope" +} From 6e31176f04da9b91a60e2b90d4d19e1ea9f431b8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 9 Aug 2024 09:43:46 -0400 Subject: [PATCH 06/18] Support UDFs/UDAFs --- .../apache/spark/util/SparkClassUtils.scala | 67 ++ .../scala/org/apache/spark/util/Utils.scala | 68 -- .../connect/planner/SparkConnectPlanner.scala | 24 +- .../scala/org/apache/spark/sql/Column.scala | 18 +- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 619 ++++-------------- .../spark/sql/execution/aggregate/udaf.scala | 19 +- .../spark/sql/expressions/Aggregator.scala | 14 +- .../sql/expressions/UserDefinedFunction.scala | 108 +-- .../apache/spark/sql/expressions/udaf.scala | 11 +- .../org/apache/spark/sql/functions.scala | 106 +-- .../spark/sql/IntegratedUDFTestUtils.scala | 5 +- 13 files changed, 326 insertions(+), 739 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala index 7a4ef4a5ce81f..307006315a3c4 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala @@ -69,6 +69,73 @@ private[spark] trait SparkClassUtils { targetClass == null || targetClass.isAssignableFrom(cls) }.getOrElse(false) } + + /** Return the class name of the given object, removing all dollar signs */ + def getFormattedClassName(obj: AnyRef): String = { + getSimpleName(obj.getClass).replace("$", "") + } + + /** + * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. + * This method mimics scalatest's getSimpleNameOfAnObjectsClass. 
+ */ + def getSimpleName(cls: Class[_]): String = { + try { + cls.getSimpleName + } catch { + // TODO: the value returned here isn't even quite right; it returns simple names + // like UtilsSuite$MalformedClassObject$MalformedClass instead of MalformedClass + // The exact value may not matter much as it's used in log statements + case _: InternalError => + stripDollars(stripPackages(cls.getName)) + } + } + + /** + * Remove the packages from full qualified class name + */ + private def stripPackages(fullyQualifiedName: String): String = { + fullyQualifiedName.split("\\.").takeRight(1)(0) + } + + /** + * Remove trailing dollar signs from qualified class name, + * and return the trailing part after the last dollar sign in the middle + */ + @scala.annotation.tailrec + final def stripDollars(s: String): String = { + val lastDollarIndex = s.lastIndexOf('$') + if (lastDollarIndex < s.length - 1) { + // The last char is not a dollar sign + if (lastDollarIndex == -1 || !s.contains("$iw")) { + // The name does not have dollar sign or is not an interpreter + // generated class, so we should return the full string + s + } else { + // The class name is interpreter generated, + // return the part after the last dollar sign + // This is the same behavior as getClass.getSimpleName + s.substring(lastDollarIndex + 1) + } + } + else { + // The last char is a dollar sign + // Find last non-dollar char + val lastNonDollarChar = s.findLast(_ != '$') + lastNonDollarChar match { + case None => s + case Some(c) => + val lastNonDollarIndex = s.lastIndexOf(c) + if (lastNonDollarIndex == -1) { + s + } else { + // Strip the trailing dollar signs + // Invoke stripDollars again to get the simple name + stripDollars(s.substring(0, lastNonDollarIndex + 1)) + } + } + } + } } private[spark] object SparkClassUtils extends SparkClassUtils diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a37aedfcb635a..ff541fa5004b3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1753,12 +1753,6 @@ private[spark] object Utils Files.createSymbolicLink(dst.toPath, src.toPath) } - - /** Return the class name of the given object, removing all dollar signs */ - def getFormattedClassName(obj: AnyRef): String = { - getSimpleName(obj.getClass).replace("$", "") - } - /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ @@ -2814,68 +2808,6 @@ private[spark] object Utils Hex.encodeHexString(secretBytes) } - /** - * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. - * This method mimics scalatest's getSimpleNameOfAnObjectsClass. 
- */ - def getSimpleName(cls: Class[_]): String = { - try { - cls.getSimpleName - } catch { - // TODO: the value returned here isn't even quite right; it returns simple names - // like UtilsSuite$MalformedClassObject$MalformedClass instead of MalformedClass - // The exact value may not matter much as it's used in log statements - case _: InternalError => - stripDollars(stripPackages(cls.getName)) - } - } - - /** - * Remove the packages from full qualified class name - */ - private def stripPackages(fullyQualifiedName: String): String = { - fullyQualifiedName.split("\\.").takeRight(1)(0) - } - - /** - * Remove trailing dollar signs from qualified class name, - * and return the trailing part after the last dollar sign in the middle - */ - @scala.annotation.tailrec - def stripDollars(s: String): String = { - val lastDollarIndex = s.lastIndexOf('$') - if (lastDollarIndex < s.length - 1) { - // The last char is not a dollar sign - if (lastDollarIndex == -1 || !s.contains("$iw")) { - // The name does not have dollar sign or is not an interpreter - // generated class, so we should return the full string - s - } else { - // The class name is interpreter generated, - // return the part after the last dollar sign - // This is the same behavior as getClass.getSimpleName - s.substring(lastDollarIndex + 1) - } - } - else { - // The last char is a dollar sign - // Find last non-dollar char - val lastNonDollarChar = s.findLast(_ != '$') - lastNonDollarChar match { - case None => s - case Some(c) => - val lastNonDollarIndex = s.lastIndexOf(c) - if (lastNonDollarIndex == -1) { - s - } else { - // Strip the trailing dollar signs - // Invoke stripDollars again to get the simple name - stripDollars(s.substring(0, lastNonDollarIndex + 1)) - } - } - } - } - /** * Regular expression matching full width characters. 
* diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 43b300a11a49d..6b9136cf18a35 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -67,7 +67,7 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, Spark import org.apache.spark.sql.connect.utils.MetricGenerator import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.command.CreateViewCommand import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -77,7 +77,7 @@ import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPy import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper -import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} +import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction, UserDefinedFunctionUtils} import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils} import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst} import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} @@ -1718,9 +1718,9 @@ class SparkConnectPlanner( val udf = fun.getScalarScalaUdf val udfPacket = unpackUdf(fun) if (udf.getAggregate) { - transformScalaFunction(fun) - .asInstanceOf[UserDefinedAggregator[Any, Any, Any]] - .scalaAggregator(fun.getArgumentsList.asScala.map(transformExpression).toSeq) + ScalaAggregator( + transformScalaFunction(fun).asInstanceOf[UserDefinedAggregator[Any, Any, Any]], + fun.getArgumentsList.asScala.map(transformExpression).toSeq) .toAggregateExpression() } else { ScalaUDF( @@ -1744,7 +1744,7 @@ class SparkConnectPlanner( UserDefinedAggregator( aggregator = udfPacket.function.asInstanceOf[Aggregator[Any, Any, Any]], inputEncoder = ExpressionEncoder(udfPacket.inputEncoders.head), - name = Option(fun.getFunctionName), + givenName = Option(fun.getFunctionName), nullable = udf.getNullable, deterministic = fun.getDeterministic) } else { @@ -1753,7 +1753,7 @@ class SparkConnectPlanner( dataType = transformDataType(udf.getOutputType), inputEncoders = udfPacket.inputEncoders.map(e => Try(ExpressionEncoder(e)).toOption), outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)), - name = Option(fun.getFunctionName), + givenName = Option(fun.getFunctionName), nullable = udf.getNullable, deterministic = fun.getDeterministic) } @@ -1899,15 +1899,7 @@ class SparkConnectPlanner( fun: org.apache.spark.sql.expressions.UserDefinedFunction, exprs: Seq[Expression]): ScalaUDF = { val f = fun.asInstanceOf[org.apache.spark.sql.expressions.SparkUserDefinedFunction] - ScalaUDF( - function = f.f, - dataType = 
f.dataType, - children = exprs, - inputEncoders = f.inputEncoders, - outputEncoder = f.outputEncoder, - udfName = f.name, - nullable = f.nullable, - udfDeterministic = f.deterministic) + UserDefinedFunctionUtils.toScalaUDF(f, exprs) } private def extractProtobufArgs(children: Seq[Expression]) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 72f5d0a2ba25f..cf64de3686a53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -89,12 +89,19 @@ private[spark] object Column { @Stable class TypedColumn[-T, U]( node: ColumnNode, - private[sql] val encoder: ExpressionEncoder[U]) + private[sql] val encoder: Encoder[U], + private[sql] val inputType: Option[(ExpressionEncoder[_], Seq[Attribute])] = None) extends Column(node) { - // TODO get rid of this. - // This requires one or two more ColumnNodes... - def this(expr: Expression, encoder: ExpressionEncoder[U]) = this(Extension(expr), encoder) + override lazy val expr: Expression = { + val expression = internal.ColumnNodeToExpressionConverter(node) + inputType match { + case Some((inputEncoder, inputAttributes)) => + TypedAggUtils.withInputType(expression, inputEncoder, inputAttributes) + case None => + expression + } + } /** * Inserts the specific input type and schema into any expressions that are expected to operate @@ -103,8 +110,7 @@ class TypedColumn[-T, U]( private[sql] def withInputType( inputEncoder: ExpressionEncoder[_], inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { - val newExpr = TypedAggUtils.withInputType(expr, inputEncoder, inputAttributes) - new TypedColumn[T, U](newExpr, encoder) + new TypedColumn[T, U](node, encoder, Option((inputEncoder, inputAttributes))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c7511737b2b3f..94129d2e8b58b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1640,7 +1640,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { - implicit val encoder: ExpressionEncoder[U1] = c1.encoder + implicit val encoder: ExpressionEncoder[U1] = encoderFor(c1.encoder) val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) if (!encoder.isSerializedAsStructForTopLevel) { @@ -1657,7 +1657,7 @@ class Dataset[T] private[sql]( * that cast appropriately for the user facing interface. */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoders = columns.map(_.encoder) + val encoders = columns.map(c => encoderFor(c.encoder)) val namedColumns = columns.map(_.withInputType(exprEnc, logicalPlan.output).named) val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 52ab633cd75a7..a672f29966df7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -968,7 +968,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * that cast appropriately for the user facing interface. 
*/ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoders = columns.map(_.encoder) + val encoders = columns.map(c => encoderFor(c.encoder)) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) val keyColumn = TypedAggUtils.aggKeyColumn(kExprEnc, groupingAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index d0d5beee9945a..7a8cfa5c9b623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -20,21 +20,19 @@ package org.apache.spark.sql import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag -import scala.util.Try import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ -import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.JavaTypeInference import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction -import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction} +import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction, UserDefinedFunctionUtils} +import org.apache.spark.sql.expressions.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -50,8 +48,6 @@ import org.apache.spark.util.Utils @Stable class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { - import UDFRegistration._ - protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { log.debug( s""" @@ -110,26 +106,36 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 2.2.0 */ def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { - udf.withName(name) match { + val named = udf.withName(name) + val builder: Seq[Expression] => Expression = named match { case udaf: UserDefinedAggregator[_, _, _] => - def builder(children: Seq[Expression]) = udaf.scalaAggregator(children) - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - udaf - case other => - def builder(children: Seq[Expression]) = other.apply(children.map(Column.apply) : _*).expr - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - other + ScalaAggregator(udaf, _) + case udf: SparkUserDefinedFunction => + val expectedParameterCount = udf.inputEncoders.size + children => { + val actualParameterCount = children.length + if (expectedParameterCount == actualParameterCount) { + toScalaUDF(udf, children) + } else { + throw QueryCompilationErrors.wrongNumArgsError( + name, + expectedParameterCount.toString, + 
actualParameterCount) + } + } } + functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") + named } // scalastyle:off line.size.limit /* register 0-22 were generated by this script - (0 to 22).foreach { x => - val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputEncoders = (1 to x).foldRight("Nil")((i, s) => {s"Try(ExpressionEncoder[A$i]()).toOption :: $s"}) + (0 to 10).foreach { x => + val types = (1 to x).foldRight("RT")((i, s) => s"A$i, $s") + val typeSeq = "RT" +: (1 to x).map(i => s"A$i") + val typeTags = typeSeq.map(t => s"$t: TypeTag").mkString(", ") println(s""" |/** | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). @@ -137,42 +143,53 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | * @since 1.3.0 | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - | val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - | val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = $inputEncoders - | val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - | val finalUdf = if (nullable) udf else udf.asNonNullable() - | def builder(e: Seq[Expression]) = if (e.length == $x) { - | finalUdf.createScalaUDF(e) - | } else { - | throw QueryCompilationErrors.wrongNumArgsError(name, "$x", e.length) - | } - | functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - | finalUdf + | register(name, functions.udf(func)) |}""".stripMargin) } - (0 to 22).foreach { i => + (11 to 22).foreach { x => + val types = (1 to x).foldRight("RT")((i, s) => s"A$i, $s") + val typeSeq = "RT" +: (1 to x).map(i => s"A$i") + val typeTags = typeSeq.map(t => s"$t: TypeTag").mkString(", ") + val implicitTypeTags = typeSeq.map(t => s"implicitly[TypeTag[$t]]").mkString(", ") + println(s""" + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | register(name, UserDefinedFunctionUtils.toUDF(func, $implicitTypeTags)) + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val version = if (i == 0) "2.3.0" else "1.3.0" + println(s""" + |/** + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). + | * @since $version + | */ + |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { + | register(name, functions.udf(f, returnType)) + |}""".stripMargin) + } + + (11 to 22).foreach { i => val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") - val version = if (i == 0) "2.3.0" else "1.3.0" val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)" println(s""" |/** | * Register a deterministic Java UDF$i instance as user-defined function (UDF). 
- | * @since $version + | * @since 1.3.0 | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { - | val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) | val func = $funcCall - | def builder(e: Seq[Expression]) = if (e.length == $i) { - | ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - | } else { - | throw QueryCompilationErrors.wrongNumArgsError(name, "$i", e.length) - | } - | functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + | register(name, UserDefinedFunctionUtils.toUDF(func, returnType, $i)) |}""".stripMargin) } */ @@ -183,18 +200,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 0) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "0", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -203,18 +209,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 1) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "1", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -223,18 +218,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 2) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "2", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -243,18 +227,7 @@ class UDFRegistration 
private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 3) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "3", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -263,18 +236,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 4) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "4", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -283,18 +245,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 5) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "5", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -303,18 +254,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: 
TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 6) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "6", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -323,18 +263,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 7) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "7", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -343,18 +272,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 8) { - finalUdf.createScalaUDF(e) 
- } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "8", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -363,18 +281,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 9) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "9", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -383,18 +290,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 10) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "10", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, functions.udf(func)) } /** @@ -403,18 +299,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { - val 
outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 11) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "11", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]])) } /** @@ -423,18 +308,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 12) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "12", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]])) } /** @@ -443,18 +317,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ 
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 13) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "13", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]])) } /** @@ -463,18 +326,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 14) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "14", e.length) - } - 
functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]])) } /** @@ -483,18 +335,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 15) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "15", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]])) } /** @@ -503,18 +344,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: 
Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 16) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "16", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]])) } /** @@ -523,18 +353,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 17) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "17", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], 
implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]])) } /** @@ -543,18 +362,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 18) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "18", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]])) } /** @@ -563,18 +371,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = 
outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 19) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "19", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]])) } /** @@ -583,18 +380,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Nil 
- val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 20) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "20", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]])) } /** @@ -603,18 +389,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 21) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "21", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], 
implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]])) } /** @@ -623,18 +398,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) - val finalUdf = if (nullable) udf else udf.asNonNullable() - def builder(e: Seq[Expression]) = if (e.length == 22) { - finalUdf.createScalaUDF(e) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "22", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - finalUdf + register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]], implicitly[TypeTag[A22]])) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -697,12 +461,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends throw QueryCompilationErrors.udfClassWithTooManyTypeArgumentsError(n) } } catch { - case e @ (_: InstantiationException | _: IllegalArgumentException) => + case _: InstantiationException | _: IllegalArgumentException => throw QueryCompilationErrors.classWithoutPublicNonArgumentConstructorError(className) } } } 
catch { - case e: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className) + case _: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className) } } @@ -722,8 +486,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val udaf = clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction] register(name, udaf) } catch { - case e: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className) - case e @ (_: InstantiationException | _: IllegalArgumentException) => + case _: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className) + case _: InstantiationException | _: IllegalArgumentException => throw QueryCompilationErrors.classWithoutPublicNonArgumentConstructorError(className) } } @@ -733,14 +497,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = () => f.asInstanceOf[UDF0[Any]].call() - def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "0", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -748,14 +505,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "1", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -763,14 +513,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "2", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -778,14 +521,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "3", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -793,14 +529,7 @@ class 
UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "4", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -808,14 +537,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "5", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -823,14 +545,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "6", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -838,14 +553,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "7", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -853,14 +561,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "8", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, 
functions.udf(f, returnType)) } /** @@ -868,14 +569,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "9", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -883,14 +577,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) - val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "10", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, functions.udf(f, returnType)) } /** @@ -898,14 +585,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "11", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 11)) } /** @@ -913,14 +594,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "12", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 12)) } /** @@ -928,14 +603,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = 
CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "13", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 13)) } /** @@ -943,14 +612,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "14", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 14)) } /** @@ -958,14 +621,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "15", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 15)) } /** @@ -973,14 +630,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "16", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 16)) } /** @@ -988,14 +639,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = 
CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "17", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 17)) } /** @@ -1003,14 +648,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "18", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 18)) } /** @@ -1018,14 +657,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "19", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 19)) } /** @@ -1033,14 +666,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "20", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 20)) } /** @@ -1048,14 
+675,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "21", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 21)) } /** @@ -1063,30 +684,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType) val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, replaced, e, Nil, udfName = Some(name)) - } else { - throw QueryCompilationErrors.wrongNumArgsError(name, "22", e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder, "java_udf") + register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 22)) } // scalastyle:on line.size.limit } - -private[sql] object UDFRegistration { - /** - * Obtaining the schema of output encoder for `ScalaUDF`. - * - * As the serialization in `ScalaUDF` is for individual column, not the whole row, - * we just take the data type of vanilla object serializer, not `serializer` which - * is transformed somehow for top-level row. 
- */ - def outputSchema(outputEncoder: ExpressionEncoder[_]): ScalaReflection.Schema = { - ScalaReflection.Schema(outputEncoder.objSerializer.dataType, - outputEncoder.objSerializer.nullable) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index e517376bc5fc0..ffef4996fe052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes -import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedAggregator} import org.apache.spark.sql.types._ /** @@ -554,6 +554,21 @@ case class ScalaAggregator[IN, BUF, OUT]( copy(children = newChildren) } +object ScalaAggregator { + def apply[IN, BUF, OUT]( + uda: UserDefinedAggregator[IN, BUF, OUT], + children: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = { + new ScalaAggregator( + children = children, + agg = uda.aggregator, + inputEncoder = encoderFor(uda.inputEncoder), + bufferEncoder = encoderFor(uda.aggregator.bufferEncoder), + nullable = uda.nullable, + isDeterministic = uda.deterministic, + aggregatorName = Option(uda.name)) + } +} + /** * An extension rule to resolve encoder expressions from a [[ScalaAggregator]] */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 88550fac7303f..1a2fbdc1fd116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.sql.{Encoder, TypedColumn} -import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike} /** * A base class for user-defined aggregations, which can be used in `Dataset` operations to take @@ -50,7 +49,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * @since 1.6.0 */ @SerialVersionUID(2093413866369130093L) -abstract class Aggregator[-IN, BUF, OUT] extends Serializable { +abstract class Aggregator[-IN, BUF, OUT] extends Serializable with UserDefinedFunctionLike { /** * A zero value for this aggregation. Should satisfy the property that any b + zero = b. 
@@ -94,11 +93,8 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable { * @since 1.6.0 */ def toColumn: TypedColumn[IN, OUT] = { - implicit val bEncoder = bufferEncoder - implicit val cEncoder = outputEncoder - - val expr = TypedAggregateExpression(this).toAggregateExpression() - - new TypedColumn[IN, OUT](expr, encoderFor[OUT]) + new TypedColumn[IN, OUT]( + InvokeInlineUserDefinedFunction(this, Nil), + outputEncoder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index a75384fb0f4e0..f1a02c026952c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.expressions +import scala.reflect.runtime.universe.TypeTag +import scala.util.Try + import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Encoder} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.execution.aggregate.ScalaAggregator +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike} import org.apache.spark.sql.types.DataType /** @@ -39,7 +44,7 @@ import org.apache.spark.sql.types.DataType * @since 1.3.0 */ @Stable -sealed abstract class UserDefinedFunction { +sealed abstract class UserDefinedFunction extends UserDefinedFunctionLike { /** * Returns true when the UDF can return a nullable value. @@ -62,7 +67,9 @@ sealed abstract class UserDefinedFunction { * @since 1.3.0 */ @scala.annotation.varargs - def apply(exprs: Column*): Column + def apply(exprs: Column*): Column = { + Column(InvokeInlineUserDefinedFunction(this, exprs.map(_.node))) + } /** * Updates UserDefinedFunction with a given name. 
@@ -89,31 +96,14 @@ sealed abstract class UserDefinedFunction { private[spark] case class SparkUserDefinedFunction( f: AnyRef, dataType: DataType, - inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil, - outputEncoder: Option[ExpressionEncoder[_]] = None, - name: Option[String] = None, + inputEncoders: Seq[Option[Encoder[_]]] = Nil, + outputEncoder: Option[Encoder[_]] = None, + givenName: Option[String] = None, nullable: Boolean = true, deterministic: Boolean = true) extends UserDefinedFunction { - @scala.annotation.varargs - override def apply(exprs: Column*): Column = { - Column(createScalaUDF(exprs.map(_.expr))) - } - - private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = { - ScalaUDF( - f, - dataType, - exprs, - inputEncoders, - outputEncoder, - udfName = name, - nullable = nullable, - udfDeterministic = deterministic) - } - override def withName(name: String): SparkUserDefinedFunction = { - copy(name = Option(name)) + copy(givenName = Option(name)) } override def asNonNullable(): SparkUserDefinedFunction = { @@ -131,30 +121,19 @@ private[spark] case class SparkUserDefinedFunction( copy(deterministic = false) } } + + override def name: String = givenName.getOrElse("UDF") } private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( aggregator: Aggregator[IN, BUF, OUT], inputEncoder: Encoder[IN], - name: Option[String] = None, + givenName: Option[String] = None, nullable: Boolean = true, deterministic: Boolean = true) extends UserDefinedFunction { - @scala.annotation.varargs - def apply(exprs: Column*): Column = { - Column(scalaAggregator(exprs.map(_.expr)).toAggregateExpression()) - } - - // This is also used by udf.register(...) when it detects a UserDefinedAggregator - def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = { - val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]] - val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]] - ScalaAggregator( - exprs, aggregator, iEncoder, bEncoder, nullable, deterministic, aggregatorName = name) - } - override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = { - copy(name = Option(name)) + copy(givenName = Option(name)) } override def asNonNullable(): UserDefinedAggregator[IN, BUF, OUT] = { @@ -172,4 +151,53 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( copy(deterministic = false) } } + + override def name: String = givenName.getOrElse(super.name) +} + +private[sql] object UserDefinedFunctionUtils { + private[sql] def toUDF( + function: AnyRef, + returnTypeTag: TypeTag[_], + inputTypeTags: TypeTag[_]*): SparkUserDefinedFunction = { + val outputEncoder = ScalaReflection.encoderFor(returnTypeTag) + val inputEncoders = inputTypeTags.map { tag => + Try(ScalaReflection.encoderFor(tag)).toOption + } + SparkUserDefinedFunction( + f = function, + inputEncoders = inputEncoders, + dataType = outputEncoder.dataType, + outputEncoder = Option(outputEncoder), + nullable = outputEncoder.nullable) + } + + private[sql] def toUDF( + function: AnyRef, + returnType: DataType, + cardinality: Int): SparkUserDefinedFunction = { + SparkUserDefinedFunction( + function, + CharVarcharUtils.failIfHasCharVarchar(returnType), + inputEncoders = Seq.fill(cardinality)(None), + None) + } + + /** + * Create a [[ScalaUDF]]. + * + * This function should be moved to [[ScalaUDF]] when we move [[SparkUserDefinedFunction]] + * to sql/api. 
+ */ + def toScalaUDF(udf: SparkUserDefinedFunction, children: Seq[Expression]): ScalaUDF = { + ScalaUDF( + udf.f, + udf.dataType, + children, + udf.inputEncoders.map(_.map(encoderFor(_))), + udf.outputEncoder.map(encoderFor(_)), + udfName = udf.givenName, + nullable = udf.nullable, + udfDeterministic = udf.deterministic) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index b387695ef2379..a4aa9c312aff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike} import org.apache.spark.sql.types._ /** @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ @Stable @deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" + " via the functions.udaf(agg) method.", "3.0.0") -abstract class UserDefinedAggregateFunction extends Serializable { +abstract class UserDefinedAggregateFunction extends Serializable with UserDefinedFunctionLike { /** * A `StructType` represents data types of input arguments of this aggregate function. @@ -130,8 +130,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { */ @scala.annotation.varargs def apply(exprs: Column*): Column = { - val aggregateExpression = ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression() - Column(aggregateExpression) + Column(InvokeInlineUserDefinedFunction(this, exprs.map(_.node))) } /** @@ -142,9 +141,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { */ @scala.annotation.varargs def distinct(exprs: Column*): Column = { - val aggregateExpression = - ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression(isDistinct = true) - Column(aggregateExpression) + Column(InvokeInlineUserDefinedFunction(this, exprs.map(_.node), isDistinct = true)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 89a873deb7985..63466a758b2dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -21,14 +21,13 @@ import java.util.Collections import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag -import scala.util.Try import org.apache.spark.annotation.Stable import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} +import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction, UserDefinedFunctionUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -7881,9 +7880,10 @@ object functions { /* Use the following code to generate: (0 to 10).foreach { x => - val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) 
- val inputEncoders = (1 to x).foldRight("Nil")((i, s) => {s"Try(ExpressionEncoder[A$i]()).toOption :: $s"}) + val types = (1 to x).foldRight("RT")((i, s) => s"A$i, $s") + val typeSeq = "RT" +: (1 to x).map(i => s"A$i") + val typeTags = typeSeq.map(t => s"$t: TypeTag").mkString(", ") + val implicitTypeTags = typeSeq.map(t => s"implicitly[TypeTag[$t]]").mkString(", ") println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF). @@ -7895,11 +7895,7 @@ object functions { | * @since 1.3.0 | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - | val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - | val inputEncoders = $inputEncoders - | val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - | if (nullable) udf else udf.asNonNullable() + | UserDefinedFunctionUtils.toUDF(f, $implicitTypeTags) |}""".stripMargin) } @@ -7921,7 +7917,7 @@ object functions { | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { | val func = $funcCall - | SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill($i)(None)) + | UserDefinedFunctionUtils.toUDF(func, returnType, $i) |}""".stripMargin) } @@ -8004,11 +8000,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]]) } /** @@ -8021,11 +8013,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]]) } /** @@ -8038,11 +8026,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]]) } /** @@ -8055,11 +8039,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = 
outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]]) } /** @@ -8072,11 +8052,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]]) } /** @@ -8089,11 +8065,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]]) } /** @@ -8106,11 +8078,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]]) } /** @@ -8123,11 +8091,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: 
Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]]) } /** @@ -8140,11 +8104,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]]) } /** @@ -8157,11 +8117,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]]) } /** @@ -8174,11 +8130,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: 
TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) - val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil - val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) - if (nullable) udf else udf.asNonNullable() + UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]]) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -8196,7 +8148,7 @@ object functions { */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { val func = () => f.asInstanceOf[UDF0[Any]].call() - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(0)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 0) } /** @@ -8210,7 +8162,7 @@ object functions { */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(1)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 1) } /** @@ -8224,7 +8176,7 @@ object functions { */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(2)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 2) } /** @@ -8238,7 +8190,7 @@ object functions { */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(3)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 3) } /** @@ -8252,7 +8204,7 @@ object functions { */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(4)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 4) } /** @@ -8266,7 +8218,7 @@ object functions { */ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(5)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 5) } /** @@ -8280,7 +8232,7 @@ object functions { */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, 
Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(6)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 6) } /** @@ -8294,7 +8246,7 @@ object functions { */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(7)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 7) } /** @@ -8308,7 +8260,7 @@ object functions { */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(8)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 8) } /** @@ -8322,7 +8274,7 @@ object functions { */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(9)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 9) } /** @@ -8336,7 +8288,7 @@ object functions { */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction(func, returnType, inputEncoders = Seq.fill(10)(None)) + UserDefinedFunctionUtils.toUDF(func, returnType, 10) } // scalastyle:on parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 97be9526849ae..96385ad03a7ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction} import org.apache.spark.sql.expressions.SparkUserDefinedFunction +import org.apache.spark.sql.expressions.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType, VariantType} import org.apache.spark.util.ArrayImplicits._ @@ -1576,7 +1577,7 @@ object IntegratedUDFTestUtils extends SQLHelper { }, StringType, inputEncoders = Seq.fill(1)(None), - name = Some(name)) { + givenName = Some(name)) { override def apply(exprs: Column*): Column = { assert(exprs.length == 1, "Defined UDF only has one column") @@ -1586,7 +1587,7 @@ object IntegratedUDFTestUtils extends SQLHelper { "as input. 
Try df(name) or df.col(name)") expr.dataType } - Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), rt)) + Column(Cast(toScalaUDF(this, Cast(expr, StringType) :: Nil), rt)) } override def withName(name: String): TestInternalScalaUDF = { From ea07c58b2fca6c3e97f2b54ac2f5dcdd720f3f04 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 12 Aug 2024 21:03:01 -0400 Subject: [PATCH 07/18] Regular Fixes --- .../spark/ml/feature/StringIndexer.scala | 2 +- project/MimaExcludes.scala | 5 +- python/pyspark/sql/functions/builtin.py | 60 ++++++++----------- .../spark/sql/internal/columnNodes.scala | 33 +++++++--- .../scala/org/apache/spark/sql/Column.scala | 8 ++- .../spark/sql/expressions/WindowSpec.scala | 2 +- .../sql/internal/columnNodeSupport.scala | 12 ++-- .../spark/sql/ColumnExpressionSuite.scala | 8 +-- .../errors/QueryExecutionErrorsSuite.scala | 2 +- ...ColumnNodeToExpressionConverterSuite.scala | 30 ++++++++-- 10 files changed, 101 insertions(+), 61 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index cfc6b8e0e5ac6..b50e7d13cdee6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -198,7 +198,7 @@ class StringIndexer @Since("1.4.0") ( } else { // We don't count for NaN values. Because `StringIndexerAggregator` only processes strings, // we replace NaNs with null in advance. - nanvl(col, lit(null)).cast(StringType) + when(!isnan(col), col).cast(StringType) } } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 20e50469e8568..03bf9c89aa2dc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -107,7 +107,10 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.JobWaiter.cancel"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.cancel"), // SPARK-48901: Add clusterBy() to DataStreamWriter. 
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy"), + // SPARK-49022: Use Column API + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.this") ) // Default exclude rules diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 2828c0b46f161..5e64463e73374 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -17559,7 +17559,7 @@ def array_sort( if comparator is None: return _invoke_function_over_columns("array_sort", col) else: - return _invoke_higher_order_function("ArraySort", [col], [comparator]) + return _invoke_higher_order_function("array_sort", [col], [comparator]) @_try_remote_functions @@ -18560,7 +18560,7 @@ def from_csv( ) -def _unresolved_named_lambda_variable(*name_parts: Any) -> Column: +def _unresolved_named_lambda_variable(name: str) -> Column: """ Create `o.a.s.sql.expressions.UnresolvedNamedLambdaVariable`, convert it to o.s.sql.Column and wrap in Python `Column` @@ -18573,13 +18573,11 @@ def _unresolved_named_lambda_variable(*name_parts: Any) -> Column: name_parts : str """ from py4j.java_gateway import JVMView - from pyspark.sql.classic.column import _to_seq sc = _get_active_spark_context() - name_parts_seq = _to_seq(sc, name_parts) - expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions + internal = cast(JVMView, sc._jvm).org.apache.spark.sql.internal return Column( - cast(JVMView, sc._jvm).Column(expressions.UnresolvedNamedLambdaVariable(name_parts_seq)) + cast(JVMView, sc._jvm).Column(internal.UnresolvedNamedLambdaVariable.apply(name)) ) @@ -18629,13 +18627,11 @@ def _create_lambda(f: Callable) -> Callable: parameters = _get_lambda_parameters(f) sc = _get_active_spark_context() - expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions + internal = cast(JVMView, sc._jvm).org.apache.spark.sql.internal argnames = ["x", "y", "z"] args = [ - _unresolved_named_lambda_variable( - expressions.UnresolvedNamedLambdaVariable.freshVarName(arg) - ) + _unresolved_named_lambda_variable(arg) for arg in argnames[: len(parameters)] ] @@ -18647,10 +18643,10 @@ def _create_lambda(f: Callable) -> Callable: messageParameters={"func_name": f.__name__, "return_type": type(result).__name__}, ) - jexpr = result._jc.expr() - jargs = _to_seq(sc, [arg._jc.expr() for arg in args]) + jexpr = result._jc.node() + jargs = _to_seq(sc, [arg._jc.node() for arg in args]) - return expressions.LambdaFunction(jexpr, jargs, False) + return cast(JVMView, sc._jvm).Column(internal.LambdaFunction.apply(jexpr, jargs)) def _invoke_higher_order_function( @@ -18669,18 +18665,12 @@ def _invoke_higher_order_function( :return: a Column """ - from py4j.java_gateway import JVMView - from pyspark.sql.classic.column import _to_java_column + from pyspark.sql.classic.column import _to_seq, _to_java_column sc = _get_active_spark_context() - expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions - expr = getattr(expressions, name) - - jcols = [_to_java_column(col).expr() for col in cols] jfuns = [_create_lambda(f) for f in funs] - - return Column(cast(JVMView, sc._jvm).Column(expr(*jcols + jfuns))) - + jcols = [_to_java_column(c) for c in cols] + return 
Column(sc._jvm.Column.pysparkFn(name, _to_seq(sc, jcols + jfuns))) @overload def transform(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: @@ -18747,7 +18737,7 @@ def transform( |[1, -2, 3, -4]| +--------------+ """ - return _invoke_higher_order_function("ArrayTransform", [col], [f]) + return _invoke_higher_order_function("transform", [col], [f]) @_try_remote_functions @@ -18788,7 +18778,7 @@ def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: | true| +------------+ """ - return _invoke_higher_order_function("ArrayExists", [col], [f]) + return _invoke_higher_order_function("exists", [col], [f]) @_try_remote_functions @@ -18833,7 +18823,7 @@ def forall(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: | true| +-------+ """ - return _invoke_higher_order_function("ArrayForAll", [col], [f]) + return _invoke_higher_order_function("forall", [col], [f]) @overload @@ -18900,7 +18890,7 @@ def filter( |[2018-09-20, 2019-07-01]| +------------------------+ """ - return _invoke_higher_order_function("ArrayFilter", [col], [f]) + return _invoke_higher_order_function("filter", [col], [f]) @_try_remote_functions @@ -18973,10 +18963,10 @@ def aggregate( +----+ """ if finish is not None: - return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge, finish]) + return _invoke_higher_order_function("array_agg", [col, initialValue], [merge, finish]) else: - return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge]) + return _invoke_higher_order_function("array_agg", [col, initialValue], [merge]) @_try_remote_functions @@ -19046,10 +19036,10 @@ def reduce( +----+ """ if finish is not None: - return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge, finish]) + return _invoke_higher_order_function("reduce", [col, initialValue], [merge, finish]) else: - return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge]) + return _invoke_higher_order_function("reduce", [col, initialValue], [merge]) @_try_remote_functions @@ -19104,7 +19094,7 @@ def zip_with( |[foo_1, bar_2, 3]| +-----------------+ """ - return _invoke_higher_order_function("ZipWith", [left, right], [f]) + return _invoke_higher_order_function("zip_with", [left, right], [f]) @_try_remote_functions @@ -19144,7 +19134,7 @@ def transform_keys(col: "ColumnOrName", f: Callable[[Column, Column], Column]) - >>> sorted(row["data_upper"].items()) [('BAR', 2.0), ('FOO', -2.0)] """ - return _invoke_higher_order_function("TransformKeys", [col], [f]) + return _invoke_higher_order_function("transform_keys", [col], [f]) @_try_remote_functions @@ -19184,7 +19174,7 @@ def transform_values(col: "ColumnOrName", f: Callable[[Column, Column], Column]) >>> sorted(row["new_data"].items()) [('IT', 20.0), ('OPS', 34.0), ('SALES', 2.0)] """ - return _invoke_higher_order_function("TransformValues", [col], [f]) + return _invoke_higher_order_function("transform_values", [col], [f]) @_try_remote_functions @@ -19247,7 +19237,7 @@ def map_filter(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Co >>> sorted(row["data_filtered"].items()) [('baz', 32.0)] """ - return _invoke_higher_order_function("MapFilter", [col], [f]) + return _invoke_higher_order_function("map_filter", [col], [f]) @_try_remote_functions @@ -19328,7 +19318,7 @@ def map_zip_with( >>> sorted(row["updated_data"].items()) [('A', 1), ('B', 5), ('C', None)] """ - return _invoke_higher_order_function("MapZipWith", [col1, col2], [f]) + return 
_invoke_higher_order_function("map_zip_with", [col1, col2], [f]) @_try_remote_functions diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index 288cb37faa97b..d07dc61307ad8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.internal +import java.util.concurrent.atomic.AtomicLong + import ColumnNode._ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} @@ -74,10 +76,10 @@ private[internal] object ColumnNode { nodes.map(_.normalize().asInstanceOf[T]) def argumentsToSql(nodes: Seq[ColumnNodeLike]): String = textArgumentsToSql(nodes.map(_.sql)) - def textArgumentsToSql(parts: Seq[String]): String = parts.mkString("(", ",", ")") + def textArgumentsToSql(parts: Seq[String]): String = parts.mkString("(", ", ", ")") def elementsToSql(elements: Seq[ColumnNodeLike], prefix: String = ""): String = { if (elements.nonEmpty) { - elements.map(_.sql).mkString(prefix, ",", "") + elements.map(_.sql).mkString(prefix, ", ", "") } else { "" } @@ -263,8 +265,8 @@ private[sql] case class SortOrder( private[sql] object SortOrder { sealed abstract class SortDirection(override val sql: String) extends ColumnNodeLike - object Ascending extends SortDirection("ASCENDING") - object Descending extends SortDirection("DESCENDING") + object Ascending extends SortDirection("ASC") + object Descending extends SortDirection("DESC") sealed abstract class NullOrdering(override val sql: String) extends ColumnNodeLike object NullsFirst extends NullOrdering("NULLS FIRST") object NullsLast extends NullOrdering("NULLS LAST") @@ -351,7 +353,8 @@ private[sql] object WindowFrame { private[sql] case class LambdaFunction( function: ColumnNode, arguments: Seq[UnresolvedNamedLambdaVariable], - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override val origin: Origin) extends ColumnNode { + override private[internal] def normalize(): LambdaFunction = copy( function = function.normalize(), arguments = ColumnNode.normalize(arguments), @@ -360,6 +363,12 @@ private[sql] case class LambdaFunction( override def sql: String = argumentsToSql(arguments) + " -> " + function.sql } +object LambdaFunction { + def apply(function: ColumnNode, arguments: Seq[UnresolvedNamedLambdaVariable]): LambdaFunction = ( + new LambdaFunction(function, arguments, CurrentOrigin.get) + ) +} + /** * Variable used in a [[LambdaFunction]]. * @@ -374,6 +383,14 @@ private[sql] case class UnresolvedNamedLambdaVariable( override def sql: String = name } +object UnresolvedNamedLambdaVariable { + private val nextId = new AtomicLong() + def apply(name: String): UnresolvedNamedLambdaVariable = { + // Generate a unique name because we reuse lambda variable names (e.g. x, y, or z). + new UnresolvedNamedLambdaVariable(s"${name}_${nextId.incrementAndGet()}") + } +} + /** * Extract a value from a complex type. This can be a field from a struct, a value from a map, * or an element from an array. 
@@ -436,9 +453,9 @@ private[sql] case class CaseWhenOtherwise( origin = NO_ORIGIN) override def sql: String = - "CASE " + - branches.map(cv => s"WHEN ${cv._1.sql} THEN ${cv._2.sql}").mkString(" ") + - otherwise.map(o => s"ELSE ${o.sql}") + + "CASE" + + branches.map(cv => s" WHEN ${cv._1.sql} THEN ${cv._2.sql}").mkString + + otherwise.map(o => s" ELSE ${o.sql}").getOrElse("") + " END" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index cf64de3686a53..22931978bc122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters._ -import org.apache.spark.annotation.Stable +import org.apache.spark.annotation.{Private, Stable} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{LEFT_EXPR, RIGHT_EXPR} import org.apache.spark.sql.catalyst.analysis._ @@ -74,6 +74,12 @@ private[spark] object Column { isDistinct = isDistinct, isInternal = isInternal)) } + + /** + * Hook used by pyspark to create functions. + */ + @Private + def pysparkFn(name: String, args: Seq[Column]): Column = fn(name, args: _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index c8d7f538df43c..7da8b8dbd4b9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -68,7 +68,7 @@ class WindowSpec private[sql]( */ @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { - new WindowSpec(cols.map(_.node), cols.map(_.sortOrder), frame) + new WindowSpec(partitionSpec, cols.map(_.sortOrder), frame) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 93f89679ddac9..cef111aa300b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkException -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.{analysis, expressions, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -42,7 +41,8 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres override def apply(node: ColumnNode): Expression = CurrentOrigin.withOrigin(node.origin) { node match { case Literal(value, Some(dataType), _) => - expressions.Literal.create(value, dataType) + val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + expressions.Literal(converter(value), dataType) case Literal(value, None, _) => expressions.Literal(value) @@ -78,6 +78,10 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres isDistinct = isDistinct, isInternal = isInternal) + case Alias(child, Seq(name), None, _) => 
+ expressions.Alias(apply(child), name)( + nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) + case Alias(child, Seq(name), metadata, _) => expressions.Alias(apply(child), name)(explicitMetadata = metadata) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 936bcc21b763d..64b5128872610 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -978,15 +978,15 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("SPARK-37646: lit") { assert(lit($"foo") == $"foo") assert(lit($"foo") == $"foo") - assert(lit(1) == Column(Literal(1))) - assert(lit(null) == Column(Literal(null, NullType))) + assert(lit(1).expr == Column(Literal(1)).expr) + assert(lit(null).expr == Column(Literal(null, NullType)).expr) } test("typedLit") { assert(typedLit($"foo") == $"foo") assert(typedLit($"foo") == $"foo") - assert(typedLit(1) == Column(Literal(1))) - assert(typedLit[String](null) == Column(Literal(null, StringType))) + assert(typedLit(1).expr == Column(Literal(1)).expr) + assert(typedLit[String](null).expr == Column(Literal(null, StringType)).expr) val df = Seq(Tuple1(0)).toDF("a") // Only check the types `lit` cannot handle diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 4a748d590feb1..b9f4e82cdd3c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -211,7 +211,7 @@ class QueryExecutionErrorsSuite test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") { def checkUnsupportedTypeInLiteral(v: Any, literal: String, dataType: String): Unit = { checkError( - exception = intercept[SparkRuntimeException] { lit(v) }, + exception = intercept[SparkRuntimeException] { lit(v).expr }, errorClass = "UNSUPPORTED_FEATURE.LITERAL_TYPE", parameters = Map("value" -> literal, "type" -> dataType), sqlState = "0A000") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala index 50cb7f437adb1..e762e62a2f711 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.internal import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.Encoders -import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.sql.catalyst.{analysis, expressions, InternalRow} import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, AgnosticEncoders} -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -29,6 +28,7 @@ import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.aggregate import 
org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for [[ColumnNode]] to [[Expression]] conversions. @@ -58,6 +58,8 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { d.copy(inputAttributes = d.inputAttributes.map(_.withExprId(ExprId(0)))) case a: expressions.aggregate.AggregateExpression => a.copy(resultId = ExprId(0)) + case expressions.UnresolvedNamedLambdaVariable(Seq(name)) => + expressions.UnresolvedNamedLambdaVariable(name.takeWhile(_ != '_') :: Nil) } test("literal") { @@ -65,6 +67,17 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { testConversion( Literal("foo", Option(StringType)), expressions.Literal.create("foo", StringType)) + val value = (12.0, "north", 60.0, "west") + val dataType = new StructType() + .add("_1", DoubleType) + .add("_2", StringType) + .add("_3", DoubleType) + .add("_4", StringType) + testConversion( + Literal((12.0, "north", 60.0, "west"), Option(dataType)), + expressions.Literal( + InternalRow(12.0, UTF8String.fromString("north"), 60.0, UTF8String.fromString("west")), + dataType)) } test("attribute") { @@ -123,7 +136,13 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { test("alias") { testConversion( Alias(Literal("qwe"), "newA" :: Nil), - expressions.Alias(expressions.Literal("qwe"), "newA")()) + expressions.Alias(expressions.Literal("qwe"), "newA")( + nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY))) + val metadata = new MetadataBuilder().putLong("q", 10).build() + testConversion( + Alias(UnresolvedAttribute("a"), "b" :: Nil, Option(metadata)), + expressions.Alias(analysis.UnresolvedAttribute("a"), "b")( + explicitMetadata = Option(metadata))) testConversion( Alias(UnresolvedAttribute("complex"), "newA" :: "newB" :: Nil), analysis.MultiAlias(analysis.UnresolvedAttribute("complex"), Seq("newA", "newB"))) @@ -332,7 +351,8 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { UserDefinedAggregator( aggregator = int2LongSum, inputEncoder = AgnosticEncoders.PrimitiveIntEncoder, - nullable = false), + nullable = false, + givenName = Option("int2LongSum")), UnresolvedAttribute("i_col") :: Nil), aggregate.ScalaAggregator( children = analysis.UnresolvedAttribute("i_col") :: Nil, From 73b1812d741aad91b6754e160c4c68d61fe1c0c6 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 12 Aug 2024 21:03:10 -0400 Subject: [PATCH 08/18] UDF Fixes --- .../catalyst/encoders/AgnosticEncoder.scala | 6 - .../apache/spark/sql/UDFRegistration.scala | 177 ++++++++++-------- .../sql/expressions/UserDefinedFunction.scala | 64 ++++--- .../org/apache/spark/sql/functions.scala | 56 +++--- .../spark/sql/IntegratedUDFTestUtils.scala | 47 ++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 3 +- 6 files changed, 179 insertions(+), 174 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index 039ee5ad1e224..9133abce88adc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -48,12 +48,6 @@ trait AgnosticEncoder[T] extends Encoder[T] { } object AgnosticEncoders { - object UnboundEncoder extends AgnosticEncoder[Any] { - override def isPrimitive: Boolean = false - 
override def dataType: DataType = NullType - override def clsTag: ClassTag[Any] = classTag[Any] - } - case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E]) extends AgnosticEncoder[Option[E]] { override def isPrimitive: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 7a8cfa5c9b623..5eec2a6878b5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -28,10 +28,11 @@ import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.JavaTypeInference import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction -import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction, UserDefinedFunctionUtils} +import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.expressions.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -106,11 +107,38 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 2.2.0 */ def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { + register(name, udf, "scala_udf", validateParameterCount = false) + } + + private def registerScalaUDF( + name: String, + func: AnyRef, + returnTypeTag: TypeTag[_], + inputTypeTags: TypeTag[_]*): UserDefinedFunction = { + val udf = SparkUserDefinedFunction(func, returnTypeTag, inputTypeTags: _*) + register(name, udf, "scala_udf", validateParameterCount = true) + } + + private def registerJavaUDF( + name: String, + func: AnyRef, + returnDataType: DataType, + cardinality: Int): UserDefinedFunction = { + val validatedReturnDataType = CharVarcharUtils.failIfHasCharVarchar(returnDataType) + val udf = SparkUserDefinedFunction(func, validatedReturnDataType, cardinality) + register(name, udf, "java_udf", validateParameterCount = true) + } + + private def register( + name: String, + udf: UserDefinedFunction, + source: String, + validateParameterCount: Boolean): UserDefinedFunction = { val named = udf.withName(name) val builder: Seq[Expression] => Expression = named match { case udaf: UserDefinedAggregator[_, _, _] => ScalaAggregator(udaf, _) - case udf: SparkUserDefinedFunction => + case udf: SparkUserDefinedFunction if validateParameterCount => val expectedParameterCount = udf.inputEncoders.size children => { val actualParameterCount = children.length @@ -123,8 +151,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends actualParameterCount) } } + case udf: SparkUserDefinedFunction => + toScalaUDF(udf, _) } - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") + functionRegistry.createOrReplaceTempFunction(name, builder, source) named } @@ -132,22 +162,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 10).foreach { x => - val types = (1 to 
x).foldRight("RT")((i, s) => s"A$i, $s") - val typeSeq = "RT" +: (1 to x).map(i => s"A$i") - val typeTags = typeSeq.map(t => s"$t: TypeTag").mkString(", ") - println(s""" - |/** - | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). - | * @tparam RT return type of UDF. - | * @since 1.3.0 - | */ - |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - | register(name, functions.udf(func)) - |}""".stripMargin) - } - - (11 to 22).foreach { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => s"A$i, $s") val typeSeq = "RT" +: (1 to x).map(i => s"A$i") val typeTags = typeSeq.map(t => s"$t: TypeTag").mkString(", ") @@ -159,24 +174,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | * @since 1.3.0 | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - | register(name, UserDefinedFunctionUtils.toUDF(func, $implicitTypeTags)) + | registerScalaUDF(name, func, $implicitTypeTags) |}""".stripMargin) } - (0 to 10).foreach { i => - val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") - val version = if (i == 0) "2.3.0" else "1.3.0" - println(s""" - |/** - | * Register a deterministic Java UDF$i instance as user-defined function (UDF). - | * @since $version - | */ - |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { - | register(name, functions.udf(f, returnType)) - |}""".stripMargin) - } - - (11 to 22).foreach { i => + (0 to 22).foreach { i => val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" @@ -189,7 +191,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { | val func = $funcCall - | register(name, UserDefinedFunctionUtils.toUDF(func, returnType, $i)) + | registerJavaUDF(name, func, returnType, $i) |}""".stripMargin) } */ @@ -200,7 +202,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]]) } /** @@ -209,7 +211,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]]) } /** @@ -218,7 +220,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]]) } /** @@ -227,7 +229,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]]) } 
/** @@ -236,7 +238,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]]) } /** @@ -245,7 +247,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]]) } /** @@ -254,7 +256,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]]) } /** @@ -263,7 +265,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]]) } /** @@ -272,7 +274,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]]) } /** @@ -281,7 +283,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]]) } /** @@ -290,7 +292,7 @@ class UDFRegistration private[sql] (functionRegistry: 
FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - register(name, functions.udf(func)) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]]) } /** @@ -299,7 +301,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]]) } /** @@ -308,7 +310,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]]) } /** @@ -317,7 +319,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], 
implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]]) } /** @@ -326,7 +328,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]]) } /** @@ -335,7 +337,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]]) } /** @@ -344,7 +346,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, 
A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]]) } /** @@ -353,7 +355,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]]) } /** @@ -362,7 +364,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], 
implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]]) } /** @@ -371,7 +373,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]]) } /** @@ -380,7 +382,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]])) + 
registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]]) } /** @@ -389,7 +391,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]]) } /** @@ -398,7 +400,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { - register(name, UserDefinedFunctionUtils.toUDF(func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], 
implicitly[TypeTag[A20]], implicitly[TypeTag[A21]], implicitly[TypeTag[A22]])) + registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]], implicitly[TypeTag[A22]]) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -494,10 +496,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /** * Register a deterministic Java UDF0 instance as user-defined function (UDF). - * @since 2.3.0 + * @since 1.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = () => f.asInstanceOf[UDF0[Any]].call() + registerJavaUDF(name, func, returnType, 0) } /** @@ -505,7 +508,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + registerJavaUDF(name, func, returnType, 1) } /** @@ -513,7 +517,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + registerJavaUDF(name, func, returnType, 2) } /** @@ -521,7 +526,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + registerJavaUDF(name, func, returnType, 3) } /** @@ -529,7 +535,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + registerJavaUDF(name, func, returnType, 4) } /** @@ -537,7 +544,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + registerJavaUDF(name, func, returnType, 5) } /** @@ -545,7 +553,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + registerJavaUDF(name, func, returnType, 6) } 
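For illustration (not part of the patch): Java-style UDFs carry no TypeTags, which is why each overload above takes an explicit return DataType. A usage sketch, assuming an existing SparkSession named `spark`:

{{{
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.StringType

// The explicit StringType stands in for the return-type information that a
// Scala closure would otherwise carry via its TypeTag.
val upperCase = new UDF1[String, String] {
  override def call(s: String): String = s.toUpperCase
}
spark.udf.register("upper_java", upperCase, StringType)
spark.sql("SELECT upper_java('hello') AS shouting").show()
}}}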
/** @@ -553,7 +562,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + registerJavaUDF(name, func, returnType, 7) } /** @@ -561,7 +571,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + registerJavaUDF(name, func, returnType, 8) } /** @@ -569,7 +580,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + registerJavaUDF(name, func, returnType, 9) } /** @@ -577,7 +589,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - register(name, functions.udf(f, returnType)) + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + registerJavaUDF(name, func, returnType, 10) } /** @@ -586,7 +599,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 11)) + registerJavaUDF(name, func, returnType, 11) } /** @@ -595,7 +608,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 12)) + registerJavaUDF(name, func, returnType, 12) } /** @@ -604,7 +617,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 13)) + registerJavaUDF(name, func, returnType, 13) } /** @@ -613,7 +626,7 @@ class UDFRegistration private[sql] (functionRegistry: 
FunctionRegistry) extends */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 14)) + registerJavaUDF(name, func, returnType, 14) } /** @@ -622,7 +635,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 15)) + registerJavaUDF(name, func, returnType, 15) } /** @@ -631,7 +644,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 16)) + registerJavaUDF(name, func, returnType, 16) } /** @@ -640,7 +653,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 17)) + registerJavaUDF(name, func, returnType, 17) } /** @@ -649,7 +662,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 18)) + registerJavaUDF(name, func, returnType, 18) } /** @@ -658,7 +671,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 19)) + registerJavaUDF(name, func, returnType, 19) } 
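The `asInstanceOf[UDF...[Any, ...]].call(_: Any, ...)` bodies above all rely on the same trick: cast the Java functional interface to Any-typed parameters and partially apply `call`, yielding a plain Scala function of the matching arity that the generic registration path can handle uniformly. A self-contained sketch of the idiom, using a stand-in interface rather than Spark's own UDF2:

{{{
// MyUDF2 is a hypothetical stand-in for org.apache.spark.sql.api.java.UDF2.
object EtaExpansionDemo extends App {
  trait MyUDF2[T1, T2, R] { def call(t1: T1, t2: T2): R }

  val concat: MyUDF2[String, String, String] = (a, b) => a + b

  // Cast to Any-typed parameters, then partially apply `call` to obtain a Function2.
  val asFunction2: (Any, Any) => Any =
    concat.asInstanceOf[MyUDF2[Any, Any, Any]].call(_: Any, _: Any)

  println(asFunction2("foo", "bar")) // prints: foobar
}
}}}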
/** @@ -667,7 +680,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 20)) + registerJavaUDF(name, func, returnType, 20) } /** @@ -676,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 21)) + registerJavaUDF(name, func, returnType, 21) } /** @@ -685,7 +698,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - register(name, UserDefinedFunctionUtils.toUDF(func, returnType, 22)) + registerJavaUDF(name, func, returnType, 22) } // scalastyle:on line.size.limit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index f1a02c026952c..e3209cf7288ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -23,9 +23,8 @@ import scala.util.Try import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Encoder} import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoders} import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike} import org.apache.spark.sql.types.DataType @@ -125,6 +124,35 @@ private[spark] case class SparkUserDefinedFunction( override def name: String = givenName.getOrElse("UDF") } +object SparkUserDefinedFunction { + private[sql] def apply( + function: AnyRef, + returnTypeTag: TypeTag[_], + inputTypeTags: TypeTag[_]*): SparkUserDefinedFunction = { + val outputEncoder = ScalaReflection.encoderFor(returnTypeTag) + val inputEncoders = inputTypeTags.map { tag => + Try(ScalaReflection.encoderFor(tag)).toOption + } + SparkUserDefinedFunction( + f = 
function, + inputEncoders = inputEncoders, + dataType = outputEncoder.dataType, + outputEncoder = Option(outputEncoder), + nullable = outputEncoder.nullable) + } + + private[sql] def apply( + function: AnyRef, + returnType: DataType, + cardinality: Int): SparkUserDefinedFunction = { + SparkUserDefinedFunction( + function, + returnType, + inputEncoders = Seq.fill(cardinality)(None), + None) + } +} + private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( aggregator: Aggregator[IN, BUF, OUT], inputEncoder: Encoder[IN], @@ -156,33 +184,6 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( } private[sql] object UserDefinedFunctionUtils { - private[sql] def toUDF( - function: AnyRef, - returnTypeTag: TypeTag[_], - inputTypeTags: TypeTag[_]*): SparkUserDefinedFunction = { - val outputEncoder = ScalaReflection.encoderFor(returnTypeTag) - val inputEncoders = inputTypeTags.map { tag => - Try(ScalaReflection.encoderFor(tag)).toOption - } - SparkUserDefinedFunction( - f = function, - inputEncoders = inputEncoders, - dataType = outputEncoder.dataType, - outputEncoder = Option(outputEncoder), - nullable = outputEncoder.nullable) - } - - private[sql] def toUDF( - function: AnyRef, - returnType: DataType, - cardinality: Int): SparkUserDefinedFunction = { - SparkUserDefinedFunction( - function, - CharVarcharUtils.failIfHasCharVarchar(returnType), - inputEncoders = Seq.fill(cardinality)(None), - None) - } - /** * Create a [[ScalaUDF]]. * @@ -194,7 +195,10 @@ private[sql] object UserDefinedFunctionUtils { udf.f, udf.dataType, children, - udf.inputEncoders.map(_.map(encoderFor(_))), + udf.inputEncoders.map(_.collect { + // At some point it would be nice if were to support this. + case e if e != AgnosticEncoders.UnboundRowEncoder => encoderFor(e) + }), udf.outputEncoder.map(encoderFor(_)), udfName = udf.givenName, nullable = udf.nullable, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 63466a758b2dc..f5f9458b5b1f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -25,9 +25,9 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.Stable import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction, UserDefinedFunctionUtils} +import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -427,7 +427,7 @@ object functions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(ExpressionEncoder[Long]()) + count(Column(columnName)).as(AgnosticEncoders.PrimitiveLongEncoder) /** * Aggregate function: returns the number of distinct items in a group. 
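For illustration (not part of the patch): the `functions.udf` builders below now construct their `UserDefinedFunction` through the new `SparkUserDefinedFunction.apply` overloads shown above instead of `UserDefinedFunctionUtils.toUDF`, and the typed `count` still produces a `Long` result. A usage sketch, assuming an existing SparkSession named `spark`:

{{{
import org.apache.spark.sql.functions.{count, udf}
import spark.implicits._

// functions.udf still infers input/output encoders from the closure's TypeTags.
val plusOne = udf((i: Int) => i + 1)
Seq(1, 2, 3).toDF("i").select(plusOne($"i").as("i_plus_one")).show()

// The typed count column still yields a Long; only the internal encoder changed
// (ExpressionEncoder[Long] -> AgnosticEncoders.PrimitiveLongEncoder).
val total: Long = spark.range(10).select(count("id")).head()
}}}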
@@ -7895,7 +7895,7 @@ object functions { | * @since 1.3.0 | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - | UserDefinedFunctionUtils.toUDF(f, $implicitTypeTags) + | SparkUserDefinedFunction(f, $implicitTypeTags) |}""".stripMargin) } @@ -7917,7 +7917,7 @@ object functions { | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { | val func = $funcCall - | UserDefinedFunctionUtils.toUDF(func, returnType, $i) + | SparkUserDefinedFunction(func, returnType, $i) |}""".stripMargin) } @@ -7953,7 +7953,7 @@ object functions { * @note The input encoder is inferred from the input type IN. */ def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction = { - udaf(agg, ExpressionEncoder[IN]()) + udaf(agg, ScalaReflection.encoderFor[IN]) } /** @@ -8000,7 +8000,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]]) } /** @@ -8013,7 +8013,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]]) } /** @@ -8026,7 +8026,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]]) } /** @@ -8039,7 +8039,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]]) } /** @@ -8052,7 +8052,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]]) } /** @@ -8065,7 +8065,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]]) } /** @@ -8078,7 +8078,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, 
implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]]) } /** @@ -8091,7 +8091,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]]) } /** @@ -8104,7 +8104,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]]) } /** @@ -8117,7 +8117,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]]) } /** @@ -8130,7 +8130,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - UserDefinedFunctionUtils.toUDF(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]]) + SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], 
implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]]) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -8148,7 +8148,7 @@ object functions { */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { val func = () => f.asInstanceOf[UDF0[Any]].call() - UserDefinedFunctionUtils.toUDF(func, returnType, 0) + SparkUserDefinedFunction(func, returnType, 0) } /** @@ -8162,7 +8162,7 @@ object functions { */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 1) + SparkUserDefinedFunction(func, returnType, 1) } /** @@ -8176,7 +8176,7 @@ object functions { */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 2) + SparkUserDefinedFunction(func, returnType, 2) } /** @@ -8190,7 +8190,7 @@ object functions { */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 3) + SparkUserDefinedFunction(func, returnType, 3) } /** @@ -8204,7 +8204,7 @@ object functions { */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 4) + SparkUserDefinedFunction(func, returnType, 4) } /** @@ -8218,7 +8218,7 @@ object functions { */ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 5) + SparkUserDefinedFunction(func, returnType, 5) } /** @@ -8232,7 +8232,7 @@ object functions { */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 6) + SparkUserDefinedFunction(func, returnType, 6) } /** @@ -8246,7 +8246,7 @@ object functions { */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 7) + SparkUserDefinedFunction(func, returnType, 7) } /** @@ -8260,7 +8260,7 @@ object functions { */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 8) + SparkUserDefinedFunction(func, returnType, 8) } /** @@ -8274,7 +8274,7 @@ object functions { */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 9) + SparkUserDefinedFunction(func, returnType, 9) } /** @@ -8288,7 +8288,7 
@@ object functions { */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunctionUtils.toUDF(func, returnType, 10) + SparkUserDefinedFunction(func, returnType, 10) } // scalastyle:on parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 96385ad03a7ed..070f46dc74f01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -28,6 +28,7 @@ import org.scalatest.Assertions._ import org.apache.spark.TestUtils import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunction, PythonUtils, SimplePythonFunction} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource @@ -1567,40 +1568,30 @@ object IntegratedUDFTestUtils extends SQLHelper { * casted_col.cast(df.schema("col").dataType) * }}} */ - class TestInternalScalaUDF( - name: String, - returnType: Option[DataType] = None) extends SparkUserDefinedFunction( - (input: Any) => if (input == null) { - null - } else { - input.toString - }, - StringType, - inputEncoders = Seq.fill(1)(None), - givenName = Some(name)) { + case class TestScalaUDF(name: String, returnType: Option[DataType] = None) extends TestUDF { + private val udf: SparkUserDefinedFunction = { + val unnamed = functions.udf { (input: Any) => + if (input == null) { + null + } else { + input.toString + } + } + unnamed.withName(name).asInstanceOf[SparkUserDefinedFunction] + } - override def apply(exprs: Column*): Column = { + val builder: FunctionRegistry.FunctionBuilder = { exprs => assert(exprs.length == 1, "Defined UDF only has one column") - val expr = exprs.head.expr + val expr = exprs.head val rt = returnType.getOrElse { assert(expr.resolved, "column should be resolved to use the same type " + - "as input. Try df(name) or df.col(name)") + "as input. Try df(name) or df.col(name)") expr.dataType } - Column(Cast(toScalaUDF(this, Cast(expr, StringType) :: Nil), rt)) + Cast(toScalaUDF(udf, Cast(expr, StringType) :: Nil), rt) } - override def withName(name: String): TestInternalScalaUDF = { - // "withName" should overridden to return TestInternalScalaUDF. Otherwise, the current object - // is sliced and the overridden "apply" is not invoked. 
- new TestInternalScalaUDF(name) - } - } - - case class TestScalaUDF(name: String, returnType: Option[DataType] = None) extends TestUDF { - private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name, returnType) - - def apply(exprs: Column*): Column = udf(exprs: _*) + def apply(exprs: Column*): Column = Column(builder(exprs.map(_.expr))) val prettyName: String = "Scala UDF" } @@ -1612,7 +1603,9 @@ object IntegratedUDFTestUtils extends SQLHelper { case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf) case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf) case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf) - case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf) + case udf: TestScalaUDF => + val registry = session.sessionState.functionRegistry + registry.createOrReplaceTempFunction(udf.name, udf.builder, "scala_udf") case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7e940252430f8..36552d5c5487c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -775,7 +775,8 @@ class UDFSuite extends QueryTest with SharedSparkSession { errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`b`", - "proposal" -> "`a`")) + "proposal" -> "`a`"), + context = ExpectedContext("apply", ".*")) } test("wrong order of input fields for case class") { From 3e41a98927cc1034689dc766188b9d6d5b0bd52f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 13 Aug 2024 12:06:31 -0400 Subject: [PATCH 09/18] Add test for ColumnNode sql and normalize --- .../spark/sql/internal/columnNodes.scala | 24 +- .../spark/sql/internal/ColumnNodeSuite.scala | 234 ++++++++++++++++++ 2 files changed, 251 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index d07dc61307ad8..e92527ec38c06 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import ColumnNode._ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.errors.DataTypeErrorsBase import org.apache.spark.sql.types.{DataType, IntegerType, LongType, Metadata} import org.apache.spark.util.SparkClassUtils @@ -98,12 +99,16 @@ private[internal] object ColumnNode { private[sql] case class Literal( value: Any, dataType: Option[DataType] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override val origin: Origin = CurrentOrigin.get) extends ColumnNode with DataTypeErrorsBase { override private[internal] def normalize(): Literal = copy(origin = NO_ORIGIN) - // TODO make this nicer. 
override def sql: String = value match { case null => "NULL" + case v: String => toSQLValue(v) + case v: Long => toSQLValue(v) + case v: Float => toSQLValue(v) + case v: Double => toSQLValue(v) + case v: Short => toSQLValue(v) case _ => value.toString } } @@ -302,7 +307,7 @@ private[sql] case class WindowSpec( override private[internal] def sql: String = { val parts = Seq( elementsToSql(partitionColumns, "PARTITION BY "), - elementsToSql(sortColumns, "ORDER BY"), + elementsToSql(sortColumns, "ORDER BY "), optionToSql(frame)) parts.filter(_.nonEmpty).mkString(" ") } @@ -360,13 +365,18 @@ private[sql] case class LambdaFunction( arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN) - override def sql: String = argumentsToSql(arguments) + " -> " + function.sql + override def sql: String = { + val argumentsSql = arguments match { + case Seq(arg) => arg.sql + case _ => argumentsToSql(arguments) + } + argumentsSql + " -> " + function.sql + } } object LambdaFunction { - def apply(function: ColumnNode, arguments: Seq[UnresolvedNamedLambdaVariable]): LambdaFunction = ( + def apply(function: ColumnNode, arguments: Seq[UnresolvedNamedLambdaVariable]): LambdaFunction = new LambdaFunction(function, arguments, CurrentOrigin.get) - ) } /** @@ -405,7 +415,7 @@ private[sql] case class UnresolvedExtractValue( override val origin: Origin = CurrentOrigin.get) extends ColumnNode { override private[internal] def normalize(): UnresolvedExtractValue = copy( child = child.normalize(), - extraction = child.normalize(), + extraction = extraction.normalize(), origin = NO_ORIGIN) override def sql: String = s"${child.sql}[${extraction.sql}]" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala new file mode 100644 index 0000000000000..052f220d97a50 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.internal + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{functions, Dataset} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.types.{IntegerType, LongType, Metadata, MetadataBuilder, StringType} + +class ColumnNodeSuite extends SparkFunSuite { + private val simpleUdf = functions.udf((i: Int) => i + 1) + + test("sql") { + testSql(Literal(null), "NULL") + testSql(Literal(10), "10") + testSql(Literal(33L), "33L") + testSql(Literal(55.toShort), "55S") + testSql(Literal(3.14), "3.14") + testSql(Literal(Double.NaN), "NaN") + testSql(Literal(Double.NegativeInfinity), "-Infinity") + testSql(Literal(Double.PositiveInfinity), "Infinity") + testSql(Literal(3.9f), "3.9") + testSql(Literal(Float.NaN), "NaN") + testSql(Literal(Float.NegativeInfinity), "-Infinity") + testSql(Literal(Float.PositiveInfinity), "Infinity") + testSql(Literal("hello"), "'hello'") + testSql(Literal("\\_'"), "'\\\\_\\''") + testSql(Literal((1, 2)), "(1,2)") + testSql(UnresolvedStar(None), "*") + testSql(UnresolvedStar(Option("prefix")), "prefix.*") + testSql(UnresolvedAttribute("a"), "a") + testSql(SqlExpression("1 + 1"), "1 + 1") + testSql(Alias(UnresolvedAttribute("b"), Seq("new_b")), "b AS new_b") + testSql(Alias(UnresolvedAttribute("c"), Seq("x", "y", "z")), "c AS (x, y, z)") + testSql(Cast(UnresolvedAttribute("c"), IntegerType), "CAST(c AS INT)") + testSql(Cast(UnresolvedAttribute("d"), StringType, Option(Cast.Try)), "TRY_CAST(d AS STRING)") + testSql( + SortOrder(attribute("e"), SortOrder.Ascending, SortOrder.NullsLast), + "e ASC NULLS LAST") + testSql( + SortOrder(attribute("f"), SortOrder.Ascending, SortOrder.NullsFirst), + "f ASC NULLS FIRST") + testSql( + SortOrder(attribute("g"), SortOrder.Descending, SortOrder.NullsLast), + "g DESC NULLS LAST") + testSql( + SortOrder(attribute("h"), SortOrder.Descending, SortOrder.NullsFirst), + "h DESC NULLS FIRST") + testSql( + UnresolvedFunction("coalesce", Seq(Literal(null), UnresolvedAttribute("i"))), + "coalesce(NULL, i)") + val lambdaVariableX = new UnresolvedNamedLambdaVariable("x") + val lambdaVariableY = new UnresolvedNamedLambdaVariable("y") + testSql( + UnresolvedFunction( + "transform", Seq( + UnresolvedAttribute("input"), + LambdaFunction( + UnresolvedFunction("adjust", Seq(lambdaVariableX, UnresolvedAttribute("b"))), + Seq(lambdaVariableX)))), + "transform(input, x -> adjust(x, b))") + testSql( + UnresolvedFunction( + "transform", Seq( + UnresolvedAttribute("input"), + LambdaFunction( + UnresolvedFunction( + "combine", + Seq(lambdaVariableX, lambdaVariableY, UnresolvedAttribute("b"))), + Seq(lambdaVariableX, lambdaVariableY)))), + "transform(input, (x, y) -> combine(x, y, b))") + testSql(UnresolvedExtractValue(attribute("a", 2), attribute("b", 3)), "a[b]") + testSql(UpdateFields(UnresolvedAttribute("struct"), "a"), "drop_field(struct, a)") + testSql( + UpdateFields(UnresolvedAttribute("struct"), "b", Option(Literal(10.toLong))), + "update_field(struct, b, 10L)") + testSql(CaseWhenOtherwise( + Seq(UnresolvedAttribute("c1") -> UnresolvedAttribute("v1"))), + "CASE WHEN c1 THEN v1 END") + testSql(CaseWhenOtherwise( + Seq( + UnresolvedAttribute("c1") -> UnresolvedAttribute("v1"), + UnresolvedAttribute("c2") -> UnresolvedAttribute("v2")), + Option(Literal(25))), + "CASE WHEN c1 THEN v1 WHEN c2 THEN v2 ELSE 25 END") + val windowSpec = 
WindowSpec( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + Seq( + SortOrder(attribute("x"), SortOrder.Descending, SortOrder.NullsFirst), + SortOrder(attribute("y"), SortOrder.Ascending, SortOrder.NullsFirst))) + val reducedWindowSpec = windowSpec.copy( + partitionColumns = windowSpec.partitionColumns.take(1), + sortColumns = windowSpec.sortColumns.take(1)) + val window = Window( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("i"))), + WindowSpec(Nil, Nil)) + testSql(window, "sum(i) OVER ()") + testSql( + window.copy(windowSpec = windowSpec.copy(sortColumns = Nil)), + "sum(i) OVER (PARTITION BY a, b)") + testSql( + window.copy(windowSpec = windowSpec.copy(partitionColumns = Nil)), + "sum(i) OVER (ORDER BY x DESC NULLS FIRST, y ASC NULLS FIRST)") + testSql( + window.copy(windowSpec = windowSpec), + "sum(i) OVER (PARTITION BY a, b ORDER BY x DESC NULLS FIRST, y ASC NULLS FIRST)") + testSql( + window.copy(windowSpec = reducedWindowSpec), + "sum(i) OVER (PARTITION BY a ORDER BY x DESC NULLS FIRST)") + testSql( + window.copy(windowSpec = reducedWindowSpec.copy(frame = Option(WindowFrame( + WindowFrame.Row, + WindowFrame.UnboundedPreceding, + WindowFrame.CurrentRow)))), + "sum(i) OVER (PARTITION BY a ORDER BY x DESC NULLS FIRST " + + "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)") + testSql( + window.copy(windowSpec = reducedWindowSpec.copy(frame = Option(WindowFrame( + WindowFrame.Range, + WindowFrame.Value(Literal(-10)), + WindowFrame.UnboundedFollowing)))), + "sum(i) OVER (PARTITION BY a ORDER BY x DESC NULLS FIRST " + + "RANGE BETWEEN -10 AND UNBOUNDED FOLLOWING)") + testSql(InvokeInlineUserDefinedFunction(simpleUdf, Seq(UnresolvedAttribute("x"))), "UDF(x)") + testSql( + InvokeInlineUserDefinedFunction(simpleUdf.withName("smple"), Seq(UnresolvedAttribute("x"))), + "smple(x)") + } + + private def testSql(node: ColumnNode, expectedSql: String): Unit = { + assert(node.sql == expectedSql) + } + + test("normalization") { + testNormalization(Literal(1)) + testNormalization(UnresolvedStar(Option("a.b"), planId = planId())) + testNormalization(UnresolvedAttribute("x", planId = planId())) + testNormalization(UnresolvedRegex(".*", planId = planId())) + testNormalization(SqlExpression("1 + 1")) + testNormalization(attribute("a")) + testNormalization(Alias(attribute("a"), Seq("aa"), None)) + testNormalization(Cast(attribute("b"), IntegerType, Some(Cast.Try))) + testNormalization(SortOrder(attribute("c"), SortOrder.Ascending, SortOrder.NullsLast)) + val lambdaVariable = UnresolvedNamedLambdaVariable("x") + testNormalization( + UnresolvedFunction( + "transform", Seq( + attribute("input", 331), + LambdaFunction( + UnresolvedFunction("adjust", Seq(lambdaVariable, attribute("b", 2))), + Seq(lambdaVariable))))) + testNormalization(UnresolvedExtractValue(attribute("b", 2), attribute("a", 8))) + testNormalization(UpdateFields(attribute("struct", 4), "a", Option(attribute("a", 11)))) + testNormalization(CaseWhenOtherwise( + Seq( + attribute("c1", 5) -> attribute("v1", 2), + attribute("c2", 3) -> attribute("v2", 4)), + Option(attribute("v2", 5)))) + testNormalization(Window( + UnresolvedFunction("sum", Seq(attribute("a")), isInternal = true, isDistinct = true), + WindowSpec( + Seq(attribute("b", 2)), + Seq(SortOrder(attribute("c", 3), SortOrder.Descending, SortOrder.NullsFirst)), + // Not a supported frame, just here for testing. 
+ Option(WindowFrame( + WindowFrame.Range, + WindowFrame.Value(attribute("d", 3)), + WindowFrame.Value(attribute("e", 4))))))) + testNormalization(InvokeInlineUserDefinedFunction( + simpleUdf, + Seq(attribute("a", 2)))) + } + + private def testNormalization(generate: => ColumnNode): Unit = { + val a = CurrentOrigin.withOrigin(origin())(generate) + val b = CurrentOrigin.withOrigin(origin())(generate) + val c = try { + createNormalized.set(true) + CurrentOrigin.withOrigin(ColumnNode.NO_ORIGIN)(generate) + } finally { + createNormalized.set(false) + } + assert(a != a.normalized) + assert(a.normalized eq a.normalized.normalized) + assert(a != b) + assert(a.normalized == b.normalized) + assert(a.normalized == c) + } + + private val createNormalized: ThreadLocal[Boolean] = new ThreadLocal[Boolean] { + override def initialValue(): Boolean = false + } + + private val idGenerator = new AtomicInteger() + private def nextId: Int = idGenerator.incrementAndGet() + + private def origin(): Origin = Origin(line = Option(nextId)) + + private def planId(): Option[Long] = { + if (!createNormalized.get()) { + Some(nextId.toLong) + } else { + None + } + } + + private def attribute(name: String, id: Long = 1): ColumnNode = { + val metadata = if (!createNormalized.get()) { + new MetadataBuilder() + .putLong(Dataset.DATASET_ID_KEY, nextId) + .build() + } else { + Metadata.empty + } + Wrapper(AttributeReference(name, LongType, metadata = metadata)(exprId = ExprId(id))) + } +} From 763a082d79d0ef71555fd626f7a2107b6dbefde3 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 14 Aug 2024 16:16:51 -0400 Subject: [PATCH 10/18] Fix pyspark issues --- python/pyspark/pandas/internal.py | 6 ++---- python/pyspark/sql/functions/builtin.py | 16 ++++++-------- python/pyspark/sql/tests/test_dataframe.py | 1 - python/pyspark/sql/udtf.py | 6 +----- .../sql/catalyst/analysis/Analyzer.scala | 8 ++++++- .../spark/sql/api/python/PythonSQLUtils.scala | 21 +++++++++++++++++-- .../org/apache/spark/sql/functions.scala | 12 ++--------- 7 files changed, 37 insertions(+), 33 deletions(-) diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index c5fef3b138254..92d4a3357319f 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -915,10 +915,8 @@ def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySpar if is_remote(): return sdf.select(F.monotonically_increasing_id().alias(column_name), *scols) jvm = sdf.sparkSession._jvm - tag = jvm.org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS() - jexpr = F.monotonically_increasing_id()._jc.expr() - jexpr.setTagValue(tag, "distributed_index") - return sdf.select(PySparkColumn(jvm.Column(jexpr)).alias(column_name), *scols) + jcol = jvm.PythonSQLUtils.distributedIndex() + return sdf.select(PySparkColumn(jcol).alias(column_name), *scols) @staticmethod def attach_distributed_sequence_column( diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index dee1a2144a91f..0fae95a4b0143 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -18574,10 +18574,7 @@ def _unresolved_named_lambda_variable(name: str) -> Column: from py4j.java_gateway import JVMView sc = _get_active_spark_context() - internal = cast(JVMView, sc._jvm).org.apache.spark.sql.internal - return Column( - cast(JVMView, sc._jvm).Column(internal.UnresolvedNamedLambdaVariable.apply(name)) - ) + return Column(cast(JVMView, 
sc._jvm).PythonSQLUtils.unresolvedNamedLambdaVariable(name)) def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]: @@ -18642,10 +18639,9 @@ def _create_lambda(f: Callable) -> Callable: messageParameters={"func_name": f.__name__, "return_type": type(result).__name__}, ) - jexpr = result._jc.node() - jargs = _to_seq(sc, [arg._jc.node() for arg in args]) - - return cast(JVMView, sc._jvm).Column(internal.LambdaFunction.apply(jexpr, jargs)) + jexpr = result._jc + jargs = _to_seq(sc, [arg._jc for arg in args]) + return cast(JVMView, sc._jvm).PythonSQLUtils.lambdaFunction(jexpr, jargs) def _invoke_higher_order_function( @@ -18962,10 +18958,10 @@ def aggregate( +----+ """ if finish is not None: - return _invoke_higher_order_function("array_agg", [col, initialValue], [merge, finish]) + return _invoke_higher_order_function("aggregate", [col, initialValue], [merge, finish]) else: - return _invoke_higher_order_function("array_agg", [col, initialValue], [merge]) + return _invoke_higher_order_function("aggregate", [col, initialValue], [merge]) @_try_remote_functions diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 2feba0b3b345f..7dd42eecde7f8 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -114,7 +114,6 @@ def test_count_star(self): self.assertEqual(df3.select(count(df3["*"])).columns, ["count(1)"]) self.assertEqual(df3.select(count(col("*"))).columns, ["count(1)"]) - self.assertEqual(df3.select(count(col("s.*"))).columns, ["count(1)"]) def test_self_join(self): df1 = self.spark.range(10).withColumn("a", lit(0)) diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index f56b8358699d3..b3a9f5f5be992 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -382,11 +382,7 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFram assert sc._jvm is not None jcols = [_to_java_column(arg) for arg in args] + [ - sc._jvm.Column( - sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression( - key, _to_java_expr(value) - ) - ) + sc._jvm.PythonSQLUtils.namedArgumentExpression(key, _to_java_column(value)) for key, value in kwargs.items() ] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c98425ce02fa2..0975203edef67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1930,6 +1930,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private def extractStar(exprs: Seq[Expression]): Seq[Star] = exprs.flatMap(_.collect { case s: Star => s }) + private def isCountStarExpansionAllowed(arguments: Seq[Expression]): Boolean = arguments match { + case Seq(UnresolvedStar(None)) => true + case Seq(_: ResolvedStar) => true + case _ => false + } + /** * Expands the matching attribute.*'s in `child`'s output. */ @@ -1937,7 +1943,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor expr.transformUp { case f0: UnresolvedFunction if !f0.isDistinct && f0.nameParts.map(_.toLowerCase(Locale.ROOT)) == Seq("count") && - f0.arguments == Seq(UnresolvedStar(None)) => + isCountStarExpansionAllowed(f0.arguments) => // Transform COUNT(*) into COUNT(1). 
f0.copy(nameParts = Seq("count"), arguments = Seq(Literal(1))) case f1: UnresolvedFunction if containsStar(f1.arguments) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 84418a0ecc65f..45ac2c73703a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.api.python.DechunkedInputStream import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.CLASS_LOADER import org.apache.spark.security.SocketAuthServer -import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} +import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SQLConf, Wrapper} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -174,6 +174,23 @@ private[sql] object PythonSQLUtils extends Logging { def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = { Column(PandasCovar(col1.expr, col2.expr, ddof).toAggregateExpression(false)) } + + def unresolvedNamedLambdaVariable(name: String): Column = + Column(internal.UnresolvedNamedLambdaVariable.apply(name)) + + def lambdaFunction(function: Column, variables: Seq[Column]): Column = { + val arguments = variables.map(_.node.asInstanceOf[internal.UnresolvedNamedLambdaVariable]) + Column(internal.LambdaFunction(function.node, arguments)) + } + + def namedArgumentExpression(name: String, e: Column): Column = + Column(Wrapper(NamedArgumentExpression(name, e.expr))) + + def distributedIndex(): Column = { + val expr = MonotonicallyIncreasingID() + expr.setTagValue(FunctionRegistry.FUNC_ALIAS, "distributed_index") + Column(Wrapper(expr)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 52cec7ff2e0f3..63913993050d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -403,22 +403,14 @@ object functions { corr(Column(columnName1), Column(columnName2)) } - private val ONE = Column(internal.Literal(1, Option(IntegerType))) - /** * Aggregate function: returns the number of items in a group. * * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = { - val withoutStar = e.node match { - // Turn count(*) into count(1) - case internal.UnresolvedStar(None, _, _) => ONE - case _ => e - } - Column.fn("count", withoutStar) - } + def count(e: Column): Column = + Column.fn("count", e) /** * Aggregate function: returns the number of items in a group. 
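A quick sketch of the behavior the Analyzer change above is meant to preserve (assuming a running SparkSession named spark; this snippet is illustrative and not part of the patch itself): count(*) still resolves to count(1), the rewrite simply happens during analysis via isCountStarExpansionAllowed instead of inside functions.count.

  import org.apache.spark.sql.functions.{col, count}

  val df = spark.range(3).toDF("id")
  df.select(count(col("*"))).columns   // Array("count(1)"): the star is expanded by the Analyzer
  df.select(count(col("id"))).columns  // Array("count(id)")
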
From 4244ef644c82b24baba0390d696c738d95d59fbc Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 15 Aug 2024 00:41:13 -0400 Subject: [PATCH 11/18] fixes --- R/pkg/R/functions.R | 75 ++++++++----------- python/pyspark/sql/functions/builtin.py | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 6 +- .../scala/org/apache/spark/sql/Column.scala | 8 +- .../spark/sql/api/python/PythonSQLUtils.scala | 6 +- 5 files changed, 41 insertions(+), 56 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index b91124f96a6fa..9c825a99be180 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3965,19 +3965,11 @@ setMethod("row_number", #' yields unresolved \code{a.b.c} #' @return Column object wrapping JVM UnresolvedNamedLambdaVariable #' @keywords internal -unresolved_named_lambda_var <- function(...) { - jc <- newJObject( - "org.apache.spark.sql.Column", - newJObject( - "org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable", - lapply(list(...), function(x) { - handledCallJStatic( - "org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable", - "freshVarName", - x) - }) - ) - ) +unresolved_named_lambda_var <- function(name) { + jc <- handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "unresolvedNamedLambdaVariable", + name) column(jc) } @@ -3990,7 +3982,6 @@ unresolved_named_lambda_var <- function(...) { #' @return JVM \code{LambdaFunction} object #' @keywords internal create_lambda <- function(fun) { - as_jexpr <- function(x) callJMethod(x@jc, "expr") # Process function arguments parameters <- formals(fun) @@ -4011,22 +4002,18 @@ create_lambda <- function(fun) { stopifnot(class(result) == "Column") # Convert both Columns to Scala expressions - jexpr <- as_jexpr(result) - jargs <- handledCallJStatic( "org.apache.spark.api.python.PythonUtils", "toSeq", - handledCallJStatic( - "java.util.Arrays", "asList", lapply(args, as_jexpr) - ) + handledCallJStatic("java.util.Arrays", "asList", lapply(args, function(x) { x@jc })) ) # Create Scala LambdaFunction - newJObject( - "org.apache.spark.sql.catalyst.expressions.LambdaFunction", - jexpr, - jargs, - FALSE + handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "lambdaFunction", + result@jc, + jargs ) } @@ -4039,20 +4026,18 @@ create_lambda <- function(fun) { #' @return a \code{Column} representing name applied to cols with funs #' @keywords internal invoke_higher_order_function <- function(name, cols, funs) { - as_jexpr <- function(x) { + as_col <- function(x) { if (class(x) == "character") { x <- column(x) } - callJMethod(x@jc, "expr") + x@jc } - - jexpr <- do.call(newJObject, c( - paste("org.apache.spark.sql.catalyst.expressions", name, sep = "."), - lapply(cols, as_jexpr), - lapply(funs, create_lambda) - )) - - column(newJObject("org.apache.spark.sql.Column", jexpr)) + jcol <- handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "fn", + name, + c(lapply(cols, as_col), lapply(funs, create_lambda))) # check varargs invocation + column(jcol) } #' @details @@ -4068,7 +4053,7 @@ setMethod("array_aggregate", signature(x = "characterOrColumn", initialValue = "Column", merge = "function"), function(x, initialValue, merge, finish = NULL) { invoke_higher_order_function( - "ArrayAggregate", + "aggregate", cols = list(x, initialValue), funs = if (is.null(finish)) { list(merge) @@ -4129,7 +4114,7 @@ setMethod("array_exists", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayExists", + 
"exists", cols = list(x), funs = list(f) ) @@ -4145,7 +4130,7 @@ setMethod("array_filter", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayFilter", + "filter", cols = list(x), funs = list(f) ) @@ -4161,7 +4146,7 @@ setMethod("array_forall", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayForAll", + "forall", cols = list(x), funs = list(f) ) @@ -4291,7 +4276,7 @@ setMethod("array_sort", column(callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc)) } else { invoke_higher_order_function( - "ArraySort", + "array_sort", cols = list(x), funs = list(comparator) ) @@ -4309,7 +4294,7 @@ setMethod("array_transform", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayTransform", + "transform", cols = list(x), funs = list(f) ) @@ -4374,7 +4359,7 @@ setMethod("arrays_zip_with", signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"), function(x, y, f) { invoke_higher_order_function( - "ZipWith", + "zip_with", cols = list(x, y), funs = list(f) ) @@ -4447,7 +4432,7 @@ setMethod("map_filter", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "MapFilter", + "map_filter", cols = list(x), funs = list(f)) }) @@ -4504,7 +4489,7 @@ setMethod("transform_keys", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "TransformKeys", + "transform_keys", cols = list(x), funs = list(f) ) @@ -4521,7 +4506,7 @@ setMethod("transform_values", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "TransformValues", + "transform_values", cols = list(x), funs = list(f) ) @@ -4552,7 +4537,7 @@ setMethod("map_zip_with", signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"), function(x, y, f) { invoke_higher_order_function( - "MapZipWith", + "map_zip_with", cols = list(x, y), funs = list(f) ) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 0fae95a4b0143..d401e1d758ed9 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -18665,7 +18665,7 @@ def _invoke_higher_order_function( sc = _get_active_spark_context() jfuns = [_create_lambda(f) for f in funs] jcols = [_to_java_column(c) for c in cols] - return Column(sc._jvm.Column.pysparkFn(name, _to_seq(sc, jcols + jfuns))) + return Column(sc._jvm.PythonSQLUtils.fn(name, _to_seq(sc, jcols + jfuns))) @overload def transform(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0975203edef67..4c22fd7490ff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1583,6 +1583,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } else { a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) } + case c: CollectMetrics if containsStar(c.metrics) => + c.copy(metrics = buildExpandedProjectList(c.metrics, c.child)) case g: Generate if containsStar(g.generator.children) => throw 
QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF", extractStar(g.generator.children)) @@ -1891,8 +1893,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } catch { case e: AnalysisException => AnalysisContext.get.outerPlan.map { - // Only Project and Aggregate can host star expressions. - case u @ (_: Project | _: Aggregate) => + // Only Project, Aggregate, CollectMetrics can host star expressions. + case u @ (_: Project | _: Aggregate | _: CollectMetrics) => Try(s.expand(u.children.head, resolver)) match { case Success(expanded) => expanded.map(wrapOuterReference) case Failure(_) => throw e diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 22931978bc122..cf64de3686a53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters._ -import org.apache.spark.annotation.{Private, Stable} +import org.apache.spark.annotation.Stable import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{LEFT_EXPR, RIGHT_EXPR} import org.apache.spark.sql.catalyst.analysis._ @@ -74,12 +74,6 @@ private[spark] object Column { isDistinct = isDistinct, isInternal = isInternal)) } - - /** - * Hook used by pyspark to create functions. - */ - @Private - def pysparkFn(name: String, args: Seq[Column]): Column = fn(name, args: _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 45ac2c73703a6..36a6b1f7d3001 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -178,7 +178,8 @@ private[sql] object PythonSQLUtils extends Logging { def unresolvedNamedLambdaVariable(name: String): Column = Column(internal.UnresolvedNamedLambdaVariable.apply(name)) - def lambdaFunction(function: Column, variables: Seq[Column]): Column = { + @scala.annotation.varargs + def lambdaFunction(function: Column, variables: Column*): Column = { val arguments = variables.map(_.node.asInstanceOf[internal.UnresolvedNamedLambdaVariable]) Column(internal.LambdaFunction(function.node, arguments)) } @@ -191,6 +192,9 @@ private[sql] object PythonSQLUtils extends Logging { expr.setTagValue(FunctionRegistry.FUNC_ALIAS, "distributed_index") Column(Wrapper(expr)) } + + @scala.annotation.varargs + def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*) } /** From e573f7c58da63dfcb2a012fdcec8334003f7d0e7 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 15 Aug 2024 10:37:18 -0400 Subject: [PATCH 12/18] Fix Connect MiMa --- .../connect/client/CheckConnectJvmClientCompatibility.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 59b399da9a5c6..07c9e5190da00 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -300,6 +300,12 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.UDFRegistration.register"), + // Typed Column + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.*"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.sql.TypedColumn.expr"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.TypedColumn$"), + // Datasource V2 partition transforms ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"), From 4ba1d94ffc762dabafe824cab0fc9734b1c5b563 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 15 Aug 2024 11:12:05 -0400 Subject: [PATCH 13/18] Fix docs --- .../apache/spark/sql/expressions/UserDefinedFunction.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 77cfff37bf5a1..3d4e9d9a2c203 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -186,10 +186,9 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( private[sql] object UserDefinedFunctionUtils { /** - * Create a [[ScalaUDF]]. + * Convert a UDF into an (executable) ScalaUDF expressions. * - * This function should be moved to [[ScalaUDF]] when we move [[SparkUserDefinedFunction]] - * to sql/api. + * This function should be moved to ScalaUDF when we move SparkUserDefinedFunction to sql/api. */ def toScalaUDF(udf: SparkUserDefinedFunction, children: Seq[Expression]): ScalaUDF = { ScalaUDF( From 7318f60dac48145055c7cfd7ad59b7cfccc8259f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 15 Aug 2024 14:04:37 -0400 Subject: [PATCH 14/18] style --- python/pyspark/sql/functions/builtin.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index d401e1d758ed9..285b342eee42f 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -18626,10 +18626,7 @@ def _create_lambda(f: Callable) -> Callable: internal = cast(JVMView, sc._jvm).org.apache.spark.sql.internal argnames = ["x", "y", "z"] - args = [ - _unresolved_named_lambda_variable(arg) - for arg in argnames[: len(parameters)] - ] + args = [_unresolved_named_lambda_variable(arg) for arg in argnames[: len(parameters)]] result = f(*args) @@ -18667,6 +18664,7 @@ def _invoke_higher_order_function( jcols = [_to_java_column(c) for c in cols] return Column(sc._jvm.PythonSQLUtils.fn(name, _to_seq(sc, jcols + jfuns))) + @overload def transform(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: ... 
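For orientation, this is roughly the call shape the R and Python wrappers build after these commits: named lambda variables, a LambdaFunction column, and a by-name lookup through PythonSQLUtils.fn instead of instantiating Catalyst classes such as ArrayTransform directly. It is only a sketch: PythonSQLUtils is private[sql], so JVM callers are assumed to live under org.apache.spark.sql (py4j and the R backend are not bound by that restriction).

  import org.apache.spark.sql.Column
  import org.apache.spark.sql.api.python.PythonSQLUtils
  import org.apache.spark.sql.functions.{col, lit}

  val x: Column = PythonSQLUtils.unresolvedNamedLambdaVariable("x")
  val y: Column = PythonSQLUtils.unresolvedNamedLambdaVariable("y")
  val merge: Column = PythonSQLUtils.lambdaFunction(x + y, x, y)  // (x, y) -> x + y

  // aggregate(values, 0, (x, y) -> x + y), resolved by function name at analysis time.
  val summed: Column = PythonSQLUtils.fn("aggregate", col("values"), lit(0), merge)
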
From c73ef8eeb6c2082663d7d63c8506905d5f32eaf8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 15 Aug 2024 14:08:35 -0400 Subject: [PATCH 15/18] merge artifact --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 63913993050d6..be83444a8fd33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -7904,12 +7904,7 @@ object functions { | * @since 2.3.0 | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { -<<<<<<< HEAD - | val func = $funcCall - | SparkUserDefinedFunction(func, returnType, $i) -======= | SparkUserDefinedFunction(ToScalaUDF(f), returnType, $i) ->>>>>>> apache/master |}""".stripMargin) } From e365d5a7a6da2d6999d31643b0f5ba64c5aaea62 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 15 Aug 2024 23:32:11 -0400 Subject: [PATCH 16/18] style --- python/pyspark/sql/functions/builtin.py | 1 - python/pyspark/sql/udtf.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 285b342eee42f..66ccac17044c3 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -18623,7 +18623,6 @@ def _create_lambda(f: Callable) -> Callable: parameters = _get_lambda_parameters(f) sc = _get_active_spark_context() - internal = cast(JVMView, sc._jvm).org.apache.spark.sql.internal argnames = ["x", "y", "z"] args = [_unresolved_named_lambda_variable(arg) for arg in argnames[: len(parameters)]] diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index b3a9f5f5be992..5ce3e2dfd2a9e 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -373,7 +373,7 @@ def _create_judtf(self, func: Type) -> "JavaObject": return judtf def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame": - from pyspark.sql.classic.column import _to_java_column, _to_java_expr, _to_seq + from pyspark.sql.classic.column import _to_java_column, _to_seq from pyspark.sql import DataFrame, SparkSession From b818d89b5f1e5888c47b8a3ccdc6af38831c62d8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 16 Aug 2024 09:44:43 -0400 Subject: [PATCH 17/18] python typing --- python/pyspark/sql/functions/builtin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 66ccac17044c3..24b8ae82e99ad 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -18656,12 +18656,13 @@ def _invoke_higher_order_function( :return: a Column """ + from py4j.java_gateway import JVMView from pyspark.sql.classic.column import _to_seq, _to_java_column sc = _get_active_spark_context() jfuns = [_create_lambda(f) for f in funs] jcols = [_to_java_column(c) for c in cols] - return Column(sc._jvm.PythonSQLUtils.fn(name, _to_seq(sc, jcols + jfuns))) + return Column(cast(JVMView, sc._jvm).PythonSQLUtils.fn(name, _to_seq(sc, jcols + jfuns))) @overload From b4f9608695eb07954bbe26119a2877fb0c58d933 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 16 Aug 2024 16:42:27 -0400 Subject: [PATCH 18/18] Code Review --- .../connect/planner/SparkConnectPlanner.scala | 4 +- 
.../scala/org/apache/spark/sql/Column.scala | 6 +-- .../apache/spark/sql/UDFRegistration.scala | 6 +-- .../spark/sql/api/python/PythonSQLUtils.scala | 6 +-- .../sql/expressions/UserDefinedFunction.scala | 25 ----------- .../internal/UserDefinedFunctionUtils.scala | 44 +++++++++++++++++++ .../sql/internal/columnNodeSupport.scala | 9 ++-- .../sql/DataFrameWindowFramesSuite.scala | 6 +-- .../spark/sql/IntegratedUDFTestUtils.scala | 2 +- .../spark/sql/internal/ColumnNodeSuite.scala | 6 ++- ...ColumnNodeToExpressionConverterSuite.scala | 2 +- 11 files changed, 70 insertions(+), 46 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 6b9136cf18a35..c8aba5d19fe7f 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -77,8 +77,8 @@ import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPy import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper -import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction, UserDefinedFunctionUtils} -import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils} +import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} +import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils, UserDefinedFunctionUtils} import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst} import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index cf64de3686a53..26df8cd9294b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{ColumnNode, TypedAggUtils, Wrapper} +import org.apache.spark.sql.internal.{ColumnNode, ExpressionColumnNode, TypedAggUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ @@ -39,7 +39,7 @@ private[spark] object Column { def apply(colName: String): Column = new Column(colName) - def apply(expr: Expression): Column = Column(Wrapper(expr)) + def apply(expr: Expression): Column = Column(ExpressionColumnNode(expr)) def apply(node: => ColumnNode): Column = withOrigin(new Column(node)) @@ -288,7 +288,7 @@ class Column(val node: ColumnNode) extends Logging { if (this == right) { logWarning( log"Constructing trivially true equals predicate, " + - log"'${MDC(LEFT_EXPR, this)} <=> ${MDC(RIGHT_EXPR, right)}'. 
" + + log"'${MDC(LEFT_EXPR, this)} == ${MDC(RIGHT_EXPR, right)}'. " + log"Perhaps you need to use aliases.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dacffaad3cd09..4fdb84836b0fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -33,8 +33,8 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction} -import org.apache.spark.sql.expressions.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.internal.ToScalaUDF +import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -185,7 +185,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends println(s""" |/** | * Register a deterministic Java UDF$i instance as user-defined function (UDF). - | * @since 1.3.0 + | * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { | registerJavaUDF(name, ToScalaUDF(f), returnType, $i) @@ -493,7 +493,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /** * Register a deterministic Java UDF0 instance as user-defined function (UDF). - * @since 1.3.0 + * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 36a6b1f7d3001..dbb3a333bfb11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.internal.{SQLConf, Wrapper} +import org.apache.spark.sql.internal.{ExpressionColumnNode, SQLConf} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -185,12 +185,12 @@ private[sql] object PythonSQLUtils extends Logging { } def namedArgumentExpression(name: String, e: Column): Column = - Column(Wrapper(NamedArgumentExpression(name, e.expr))) + Column(ExpressionColumnNode(NamedArgumentExpression(name, e.expr))) def distributedIndex(): Column = { val expr = MonotonicallyIncreasingID() expr.setTagValue(FunctionRegistry.FUNC_ALIAS, "distributed_index") - Column(Wrapper(expr)) + Column(ExpressionColumnNode(expr)) } @scala.annotation.varargs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 3d4e9d9a2c203..403eccfddffad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -23,9 +23,6 @@ import scala.util.Try import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Encoder} import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder -import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike} import org.apache.spark.sql.types.DataType @@ -183,25 +180,3 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( override def name: String = givenName.getOrElse(aggregator.name) } - -private[sql] object UserDefinedFunctionUtils { - /** - * Convert a UDF into an (executable) ScalaUDF expressions. - * - * This function should be moved to ScalaUDF when we move SparkUserDefinedFunction to sql/api. - */ - def toScalaUDF(udf: SparkUserDefinedFunction, children: Seq[Expression]): ScalaUDF = { - ScalaUDF( - udf.f, - udf.dataType, - children, - udf.inputEncoders.map(_.collect { - // At some point it would be nice if were to support this. - case e if e != UnboundRowEncoder => encoderFor(e) - }), - udf.outputEncoder.map(encoderFor(_)), - udfName = udf.givenName, - nullable = udf.nullable, - udfDeterministic = udf.deterministic) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala new file mode 100644 index 0000000000000..bd8735d15be13 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal + +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.expressions.SparkUserDefinedFunction + +private[sql] object UserDefinedFunctionUtils { + /** + * Convert a UDF into an (executable) ScalaUDF expressions. + * + * This function should be moved to ScalaUDF when we move SparkUserDefinedFunction to sql/api. + */ + def toScalaUDF(udf: SparkUserDefinedFunction, children: Seq[Expression]): ScalaUDF = { + ScalaUDF( + udf.f, + udf.dataType, + children, + udf.inputEncoders.map(_.collect { + // At some point it would be nice if were to support this. 
+ case e if e != UnboundRowEncoder => encoderFor(e) + }), + udf.outputEncoder.map(encoderFor(_)), + udfName = udf.givenName, + nullable = udf.nullable, + udfDeterministic = udf.deterministic) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 9da75716cb630..ea6e36680da45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.internal +import UserDefinedFunctionUtils.toScalaUDF + import org.apache.spark.SparkException import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.{analysis, expressions, CatalystTypeConverters} @@ -28,7 +30,6 @@ import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF, TypedAggregateExpression} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator} -import org.apache.spark.sql.expressions.UserDefinedFunctionUtils.toScalaUDF /** * Convert a [[ColumnNode]] into an [[Expression]]. @@ -171,7 +172,7 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case InvokeInlineUserDefinedFunction(udf: SparkUserDefinedFunction, arguments, _, _) => toScalaUDF(udf, arguments.map(apply)) - case Wrapper(expression, _) => + case ExpressionColumnNode(expression, _) => expression case node => @@ -233,10 +234,10 @@ private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressi /** * [[ColumnNode]] wrapper for an [[Expression]]. 
*/ -private[sql] case class Wrapper( +private[sql] case class ExpressionColumnNode( expression: Expression, override val origin: Origin = CurrentOrigin.get) extends ColumnNode { - override def normalize(): Wrapper = { + override def normalize(): ExpressionColumnNode = { val updated = expression.transform { case a: AttributeReference => DetectAmbiguousSelfJoin.stripColumnReferenceMetadata(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 4923ef9c4ebb9..c03c5e878427f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions import org.apache.spark.sql.catalyst.plans.logical.{Window => WindowNode} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.{SQLConf, Wrapper} +import org.apache.spark.sql.internal.{ExpressionColumnNode, SQLConf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.CalendarIntervalType @@ -506,7 +506,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { val windowSpec = Window.partitionBy($"value").orderBy($"key".asc).withFrame( internal.WindowFrame.Range, - internal.WindowFrame.Value(Wrapper(Literal.create(null, CalendarIntervalType))), + internal.WindowFrame.Value(ExpressionColumnNode(Literal.create(null, CalendarIntervalType))), internal.WindowFrame.Value(lit(2).node)) checkError( exception = intercept[AnalysisException] { @@ -528,7 +528,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { val df = Seq((1L, "1"), (1L, "1")).toDF("key", "value") val windowSpec = Window.partitionBy($"value").orderBy($"key".asc).withFrame( internal.WindowFrame.Range, - internal.WindowFrame.Value(Wrapper(NonFoldableLiteral(1))), + internal.WindowFrame.Value(ExpressionColumnNode(NonFoldableLiteral(1))), internal.WindowFrame.Value(lit(2).node)) checkError( exception = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 070f46dc74f01..44709fd309cfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction} import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.expressions.UserDefinedFunctionUtils.toScalaUDF +import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType, VariantType} import org.apache.spark.util.ArrayImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala index 052f220d97a50..7bf70695a9854 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala @@ -229,6 +229,10 @@ class ColumnNodeSuite extends SparkFunSuite { } else { Metadata.empty } - Wrapper(AttributeReference(name, LongType, metadata = metadata)(exprId = ExprId(id))) + ExpressionColumnNode(AttributeReference( + name, + LongType, + metadata = metadata)( + exprId = ExprId(id))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala index 682408c2239f8..0fbfe762df918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala @@ -387,7 +387,7 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { test("extension") { testConversion( - Wrapper(analysis.UnresolvedAttribute("bar")), + ExpressionColumnNode(analysis.UnresolvedAttribute("bar")), analysis.UnresolvedAttribute("bar")) }
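
To close the loop on the Wrapper -> ExpressionColumnNode rename, a minimal sketch of the wrapper in its final form, mirroring what the converter test above exercises (internal API: ExpressionColumnNode and the Column companion are private[sql]/private[spark], so this assumes code inside org.apache.spark.sql):

  import org.apache.spark.sql.Column
  import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
  import org.apache.spark.sql.internal.ExpressionColumnNode

  val node = ExpressionColumnNode(UnresolvedAttribute("bar"))
  val column = Column(node)    // Column.apply(Expression) performs the same wrap internally
  assert(column.node eq node)  // the raw Catalyst expression rides along on the ColumnNode AST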