Skip to content

Commit

Permalink
[SPARK-20586][SQL] Add deterministic to ScalaUDF
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Like [Hive UDFType](https://hive.apache.org/javadocs/r2.0.1/api/org/apache/hadoop/hive/ql/udf/UDFType.html), we should allow users to add the extra flags for ScalaUDF and JavaUDF too. _stateful_/_impliesOrder_ are not applicable to our Scala UDF. Thus, we only add the following two flags.

- deterministic: Certain optimizations should not be applied if UDF is not deterministic. Deterministic UDF returns same result each time it is invoked with a particular input. This determinism just needs to hold within the context of a query.

When the deterministic flag is not correctly set, the results could be wrong.

For ScalaUDF in Dataset APIs, users can call the following extra APIs for `UserDefinedFunction` to make the corresponding changes.
- `nonDeterministic`: Updates UserDefinedFunction to non-deterministic.

Also fixed the Java UDF name loss issue.

Will submit a separate PR for `distinctLike`  for UDAF

### How was this patch tested?
Added test cases for both ScalaUDF

Author: gatorsmile <gatorsmile@gmail.com>
Author: Wenchen Fan <cloud0fan@gmail.com>

Closes #17848 from gatorsmile/udfRegister.
  • Loading branch information
gatorsmile committed Jul 26, 2017
1 parent 9b4da7b commit ebc24a9
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 164 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/sql/context.py
Expand Up @@ -220,11 +220,11 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
>>> sqlContext.registerJavaFunction("javaStringLength",
... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
>>> sqlContext.sql("SELECT javaStringLength('test')").collect()
[Row(UDF(test)=4)]
[Row(UDF:javaStringLength(test)=4)]
>>> sqlContext.registerJavaFunction("javaStringLength2",
... "test.org.apache.spark.sql.JavaStringLength")
>>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF(test)=4)]
[Row(UDF:javaStringLength2(test)=4)]
"""
jdt = None
Expand Down
Expand Up @@ -1950,7 +1950,7 @@ class Analyzer(

case p => p transformExpressionsUp {

case udf @ ScalaUDF(func, _, inputs, _, _, _) =>
case udf @ ScalaUDF(func, _, inputs, _, _, _, _) =>
val parameterTypes = ScalaReflection.getParameterTypes(func)
assert(parameterTypes.length == inputs.length)

Expand Down
Expand Up @@ -24,7 +24,6 @@ import org.apache.spark.sql.types.DataType

/**
* User-defined function.
* Note that the user-defined functions must be deterministic.
* @param function The user defined scala function to run.
* Note that if you use primitive parameters, you are not able to check if it is
* null or not, and the UDF will return null for you if the primitive input is
Expand All @@ -35,18 +34,23 @@ import org.apache.spark.sql.types.DataType
* not want to perform coercion, simply use "Nil". Note that it would've been
* better to use Option of Seq[DataType] so we can use "None" as the case for no
* type coercion. However, that would require more refactoring of the codebase.
* @param udfName The user-specified name of this UDF.
* @param udfName The user-specified name of this UDF.
* @param nullable True if the UDF can return null value.
* @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
* each time it is invoked with a particular input.
*/
case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil,
udfName: Option[String] = None,
nullable: Boolean = true)
nullable: Boolean = true,
udfDeterministic: Boolean = true)
extends Expression with ImplicitCastInputTypes with NonSQLExpression {

override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

override def toString: String =
s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})"

Expand Down
243 changes: 139 additions & 104 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Large diffs are not rendered by default.

Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.expressions
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions
import org.apache.spark.sql.types.DataType

/**
Expand All @@ -35,10 +34,6 @@ import org.apache.spark.sql.types.DataType
* df.select( predict(df("score")) )
* }}}
*
* @note The user-defined functions must be deterministic. Due to optimization,
* duplicate invocations may be eliminated or the function may even be invoked more times than
* it is present in the query.
*
* @since 1.3.0
*/
@InterfaceStability.Stable
Expand All @@ -49,6 +44,7 @@ case class UserDefinedFunction protected[sql] (

private var _nameOption: Option[String] = None
private var _nullable: Boolean = true
private var _deterministic: Boolean = true

/**
* Returns true when the UDF can return a nullable value.
Expand All @@ -57,6 +53,14 @@ case class UserDefinedFunction protected[sql] (
*/
def nullable: Boolean = _nullable

/**
* Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same
* input.
*
* @since 2.3.0
*/
def deterministic: Boolean = _deterministic

/**
* Returns an expression that invokes the UDF, using the given arguments.
*
Expand All @@ -69,13 +73,15 @@ case class UserDefinedFunction protected[sql] (
exprs.map(_.expr),
inputTypes.getOrElse(Nil),
udfName = _nameOption,
nullable = _nullable))
nullable = _nullable,
udfDeterministic = _deterministic))
}

private def copyAll(): UserDefinedFunction = {
val udf = copy()
udf._nameOption = _nameOption
udf._nullable = _nullable
udf._deterministic = _deterministic
udf
}

Expand All @@ -84,22 +90,38 @@ case class UserDefinedFunction protected[sql] (
*
* @since 2.3.0
*/
def withName(name: String): this.type = {
this._nameOption = Option(name)
this
def withName(name: String): UserDefinedFunction = {
val udf = copyAll()
udf._nameOption = Option(name)
udf
}

/**
* Updates UserDefinedFunction to non-nullable.
*
* @since 2.3.0
*/
def asNonNullabe(): UserDefinedFunction = {
if (!nullable) {
this
} else {
val udf = copyAll()
udf._nullable = false
udf
}
}

/**
* Updates UserDefinedFunction with a given nullability.
* Updates UserDefinedFunction to nondeterministic.
*
* @since 2.3.0
*/
def withNullability(nullable: Boolean): UserDefinedFunction = {
if (nullable == _nullable) {
def asNondeterministic(): UserDefinedFunction = {
if (!_deterministic) {
this
} else {
val udf = copyAll()
udf._nullable = nullable
udf._deterministic = false
udf
}
}
Expand Down

0 comments on commit ebc24a9

Please sign in to comment.