From ebc24a9b7fde273ee4912f9bc1c5059703f7b31e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 25 Jul 2017 17:19:44 -0700 Subject: [PATCH] [SPARK-20586][SQL] Add deterministic to ScalaUDF ### What changes were proposed in this pull request? Like [Hive UDFType](https://hive.apache.org/javadocs/r2.0.1/api/org/apache/hadoop/hive/ql/udf/UDFType.html), we should allow users to add the extra flags for ScalaUDF and JavaUDF too. _stateful_/_impliesOrder_ are not applicable to our Scala UDF. Thus, we only add the following flag. - deterministic: Certain optimizations should not be applied if a UDF is not deterministic. A deterministic UDF returns the same result each time it is invoked with a particular input. This determinism just needs to hold within the context of a query. When the deterministic flag is not correctly set, the results could be wrong. For ScalaUDF in Dataset APIs, users can call the following extra APIs for `UserDefinedFunction` to make the corresponding changes. - `asNondeterministic`: Updates UserDefinedFunction to non-deterministic. Also fixed the Java UDF name loss issue. Will submit a separate PR for `distinctLike` for UDAF. ### How was this patch tested? Added test cases for ScalaUDF. Author: gatorsmile Author: Wenchen Fan Closes #17848 from gatorsmile/udfRegister. 
--- python/pyspark/sql/context.py | 4 +- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/expressions/ScalaUDF.scala | 10 +- .../apache/spark/sql/UDFRegistration.scala | 243 ++++++++++-------- .../sql/expressions/UserDefinedFunction.scala | 48 +++- .../org/apache/spark/sql/functions.scala | 113 +++++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 22 +- 7 files changed, 278 insertions(+), 164 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c44ab247fd3d3..b1e723cdecef3 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -220,11 +220,11 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): >>> sqlContext.registerJavaFunction("javaStringLength", ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) >>> sqlContext.sql("SELECT javaStringLength('test')").collect() - [Row(UDF(test)=4)] + [Row(UDF:javaStringLength(test)=4)] >>> sqlContext.registerJavaFunction("javaStringLength2", ... 
"test.org.apache.spark.sql.JavaStringLength") >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF(test)=4)] + [Row(UDF:javaStringLength2(test)=4)] """ jdt = None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 501e7e3c6961d..913d846a8c23b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1950,7 +1950,7 @@ class Analyzer( case p => p transformExpressionsUp { - case udf @ ScalaUDF(func, _, inputs, _, _, _) => + case udf @ ScalaUDF(func, _, inputs, _, _, _, _) => val parameterTypes = ScalaReflection.getParameterTypes(func) assert(parameterTypes.length == inputs.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index a54f6d0e11147..9df0e2e1415c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.types.DataType /** * User-defined function. - * Note that the user-defined functions must be deterministic. * @param function The user defined scala function to run. * Note that if you use primitive parameters, you are not able to check if it is * null or not, and the UDF will return null for you if the primitive input is @@ -35,8 +34,10 @@ import org.apache.spark.sql.types.DataType * not want to perform coercion, simply use "Nil". Note that it would've been * better to use Option of Seq[DataType] so we can use "None" as the case for no * type coercion. However, that would require more refactoring of the codebase. 
- * @param udfName The user-specified name of this UDF. + * @param udfName The user-specified name of this UDF. * @param nullable True if the UDF can return null value. + * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result + * each time it is invoked with a particular input. */ case class ScalaUDF( function: AnyRef, @@ -44,9 +45,12 @@ case class ScalaUDF( children: Seq[Expression], inputTypes: Seq[DataType] = Nil, udfName: Option[String] = None, - nullable: Boolean = true) + nullable: Boolean = true, + udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression { + override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override def toString: String = s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index c66d4057b9135..737afb4ac564e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -64,7 +64,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined aggregate function (UDAF). + * Registers a user-defined aggregate function (UDAF). * * @param name the name of the UDAF. * @param udaf the UDAF needs to be registered. @@ -79,8 +79,19 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function (UDF), for a UDF that's already defined using the DataFrame - * API (i.e. of type UserDefinedFunction). + * Registers a user-defined function (UDF), for a UDF that's already defined using the Dataset + * API (i.e. of type UserDefinedFunction). To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
To change a UDF to nonNullable, call the API + * `UserDefinedFunction.asNonNullabe()`. + * + * Example: + * {{{ + * val foo = udf(() => Math.random()) + * spark.udf.register("random", foo.asNondeterministic()) + * + * val bar = udf(() => "bar") + * spark.udf.register("stringLit", bar.asNonNullabe()) + * }}} * * @param name the name of the UDF. * @param udf the UDF needs to be registered. @@ -104,7 +115,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" /** - * Register a Scala closure of ${x} arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). * @tparam RT return type of UDF. * @since 1.3.0 */ @@ -112,13 +123,14 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try($inputTypes).toOption def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: $x; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() }""") } @@ -137,7 +149,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { | val func = f$anyCast.call($anyParams) | def builder(e: Seq[Expression]) = if (e.length == $i) { - | ScalaUDF($funcCall, returnType, e) + | ScalaUDF($funcCall, returnType, e, udfName = Some(name)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". Expected: $i; Found: " + e.length) @@ -148,7 +160,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ /** - * Register a Scala closure of 0 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 0 arguments as user-defined function (UDF). * @tparam RT return type of UDF. * @since 1.3.0 */ @@ -156,17 +168,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 0; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 1 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 1 arguments as user-defined function (UDF). * @tparam RT return type of UDF. * @since 1.3.0 */ @@ -174,17 +187,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 2 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 2 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -192,17 +206,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 3 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 3 arguments as user-defined function (UDF). * @tparam RT return type of UDF. * @since 1.3.0 */ @@ -210,17 +225,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 3; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 4 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 4 arguments as user-defined function (UDF). * @tparam RT return type of UDF. * @since 1.3.0 */ @@ -228,17 +244,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 5 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 5 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -246,17 +263,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 5; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 6 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 6 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -264,17 +282,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 7 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 7 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -282,17 +301,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 8 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 8 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -300,17 +320,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 9 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 9 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -318,17 +339,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 10 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 10 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -336,17 +358,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 10; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 11 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 11 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -354,17 +377,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 11; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 12 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 12 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -372,17 +396,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 13 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 13 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -390,17 +415,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 13; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 14 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 14 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -408,17 +434,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 14; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 15 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 15 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -426,17 +453,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 15; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 16 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 16 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -444,17 +472,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 17 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 17 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -462,17 +491,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 17; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 18 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 18 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -480,17 +510,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 18; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 19 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 19 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -498,17 +529,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 19; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 20 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 20 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -516,17 +548,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 21 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 21 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -534,17 +567,18 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 21; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } /** - * Register a Scala closure of 22 arguments as user-defined function (UDF). + * Registers a deterministic Scala closure of 22 arguments as user-defined function (UDF). * @tparam RT return type of UDF. 
* @since 1.3.0 */ @@ -552,13 +586,14 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 22; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) + val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + if (nullable) udf else udf.asNonNullabe() } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -581,9 +616,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends .map(_.asInstanceOf[ParameterizedType]) .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) if (udfInterfaces.length == 0) { - throw new AnalysisException(s"UDF class ${className} doesn't implement any UDF interface") + throw new AnalysisException(s"UDF class $className doesn't implement any UDF interface") } else if (udfInterfaces.length > 1) { - throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") + throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class $className") } else { try { val udf = clazz.newInstance() @@ -618,15 +653,15 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case n => - throw new AnalysisException(s"UDF class with ${n} type arguments is not supported.") + throw new AnalysisException(s"UDF class with $n type arguments is not supported.") } } catch { case e @ (_: InstantiationException | _: IllegalArgumentException) => - throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + throw new 
AnalysisException(s"Can not instantiate class $className, please make sure it has public non argument constructor") } } } catch { - case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class $className, please make sure it is on the classpath") } } @@ -659,7 +694,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF0[_], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF0[Any]].call() def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(() => func, returnType, e) + ScalaUDF(() => func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) @@ -674,7 +709,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) @@ -689,7 +724,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 2; Found: " + e.length) @@ -704,7 +739,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 3; Found: " + e.length) @@ -719,7 +754,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) @@ -734,7 +769,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 5; Found: " + e.length) @@ -749,7 +784,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) @@ -764,7 +799,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) @@ -779,7 +814,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 8; Found: " + e.length) @@ -794,7 +829,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) @@ -809,7 +844,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 10; Found: " + e.length) @@ -824,7 +859,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 11; Found: " + e.length) @@ -839,7 +874,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) @@ -854,7 +889,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 13; Found: " + e.length) @@ -869,7 +904,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 14; Found: " + e.length) @@ -884,7 +919,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 15; Found: " + e.length) @@ -899,7 +934,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) @@ -914,7 +949,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 17; Found: " + e.length) @@ -929,7 +964,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 18; Found: " + e.length) @@ -944,7 +979,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 19; Found: " + e.length) @@ -959,7 +994,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) @@ -974,7 +1009,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". 
Expected: 21; Found: " + e.length) @@ -989,7 +1024,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, returnType, e) + ScalaUDF(func, returnType, e, udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 22; Found: " + e.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 0c5f1b436591d..97b921a622636 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.Column -import org.apache.spark.sql.functions import org.apache.spark.sql.types.DataType /** @@ -35,10 +34,6 @@ import org.apache.spark.sql.types.DataType * df.select( predict(df("score")) ) * }}} * - * @note The user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. 
- * * @since 1.3.0 */ @InterfaceStability.Stable @@ -49,6 +44,7 @@ case class UserDefinedFunction protected[sql] ( private var _nameOption: Option[String] = None private var _nullable: Boolean = true + private var _deterministic: Boolean = true /** * Returns true when the UDF can return a nullable value. @@ -57,6 +53,14 @@ case class UserDefinedFunction protected[sql] ( */ def nullable: Boolean = _nullable + /** + * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same + * input. + * + * @since 2.3.0 + */ + def deterministic: Boolean = _deterministic + /** * Returns an expression that invokes the UDF, using the given arguments. * @@ -69,13 +73,15 @@ case class UserDefinedFunction protected[sql] ( exprs.map(_.expr), inputTypes.getOrElse(Nil), udfName = _nameOption, - nullable = _nullable)) + nullable = _nullable, + udfDeterministic = _deterministic)) } private def copyAll(): UserDefinedFunction = { val udf = copy() udf._nameOption = _nameOption udf._nullable = _nullable + udf._deterministic = _deterministic udf } @@ -84,22 +90,38 @@ case class UserDefinedFunction protected[sql] ( * * @since 2.3.0 */ - def withName(name: String): this.type = { - this._nameOption = Option(name) - this + def withName(name: String): UserDefinedFunction = { + val udf = copyAll() + udf._nameOption = Option(name) + udf + } + + /** + * Updates UserDefinedFunction to non-nullable. + * + * @since 2.3.0 + */ + def asNonNullabe(): UserDefinedFunction = { + if (!nullable) { + this + } else { + val udf = copyAll() + udf._nullable = false + udf + } } /** - * Updates UserDefinedFunction with a given nullability. + * Updates UserDefinedFunction to nondeterministic. 
* * @since 2.3.0 */ - def withNullability(nullable: Boolean): UserDefinedFunction = { - if (nullable == _nullable) { + def asNondeterministic(): UserDefinedFunction = { + if (!_deterministic) { this } else { val udf = copyAll() - udf._nullable = nullable + udf._deterministic = false udf } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ebdeb42b0bfb1..ccff00e570dbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3185,8 +3185,10 @@ object functions { val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" /** - * Defines a user-defined function of ${x} arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of ${x} arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3194,15 +3196,18 @@ object functions { def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try($inputTypes).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() }""") } */ /** - * Defines a user-defined function of 0 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 0 arguments as user-defined + * function (UDF). 
The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3210,12 +3215,15 @@ object functions { def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 1 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 1 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3223,12 +3231,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 2 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 2 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. 
To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3236,12 +3247,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 3 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 3 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3249,12 +3263,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 4 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. 
+ * Defines a deterministic user-defined function of 4 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3262,12 +3279,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 5 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 5 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3275,12 +3295,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 6 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 6 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3288,12 +3311,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 7 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 7 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3301,12 +3327,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 8 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 8 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3314,12 +3343,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 9 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 9 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3327,12 +3359,15 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } /** - * Defines a user-defined function of 10 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. + * Defines a deterministic user-defined function of 10 arguments as user-defined + * function (UDF). The data types are automatically inferred based on the function's + * signature. To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3340,15 +3375,17 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).toOption - UserDefinedFunction(f, dataType, inputTypes).withNullability(nullable) + val udf = UserDefinedFunction(f, dataType, inputTypes) + if (nullable) udf else udf.asNonNullabe() } // scalastyle:on parameter.number // scalastyle:on line.size.limit /** - * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must - * specify the output data type, and there is no automatic input type coercion. + * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, + * the caller must specify the output data type, and there is no automatic input type coercion. + * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. 
* * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 335b882ace92a..7f1c009ca6e7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.types.DataTypes private case class FunctionResult(f1: String, f2: String) @@ -109,9 +112,22 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(sql("select foo(5)").head().getInt(0) == 6) } - test("ZeroArgument UDF") { - spark.udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) + test("ZeroArgument non-deterministic UDF") { + val foo = udf(() => Math.random()) + spark.udf.register("random0", foo.asNondeterministic()) + val df = sql("SELECT random0()") + assert(df.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df.head().getDouble(0) >= 0.0) + + val foo1 = foo.asNondeterministic() + val df1 = testData.select(foo1()) + assert(df1.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df1.head().getDouble(0) >= 0.0) + + val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic() + val df2 = testData.select(bar()) + assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df2.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") {