Status: Closed. Showing changes from all commits.
Commits (32, all by ptkool):

- ed5e453 (Aug 10, 2017) Add nullability support.
- ca63b8f (Nov 13, 2017) Additional changes for panda UDFs
- 71c23de (Nov 13, 2017) Fix Python test failures
- d2d06d3 (Nov 14, 2017) Undo change to modules.py
- 90e1684 (Nov 30, 2017) Resolve conflicts
- 0727597 (Nov 30, 2017) Fix Python style tests
- f8e4904 (Nov 30, 2017) Resolve test failures
- d992f93 (Dec 11, 2017) Make Python API consistent with Scala API
- ca52538 (Dec 28, 2017) Fix test failures
- 9a11aa9 (Jan 3, 2018) Clean up and small corrections based on feedback
- 0bb9472 (Jan 3, 2018) More changes based on feedback
- 6348f05 (Jan 17, 2018) Resolve conflicts
- 592324d (Jan 18, 2018) Resolve conflicts
- 5af8e57 (Jan 21, 2018) Changes based on feedback
- 414545d (Jan 21, 2018) Fix Python style errors
- 3b72f0d (Jan 23, 2018) Documentation corrections
- d52e1f0 (Jan 23, 2018) Changed version added to 2.4
- e94960c (Feb 28, 2018) Raise exceptions for non-nullable functions returning null values
- e6e6dbf (Feb 28, 2018) Fix tests and Python style checks
- 64f0500 (Mar 18, 2018) Merge branch 'master' into udf_nullability
- cdd16a9 (Aug 12, 2018) Resolve conflicts
- 9038520 (Sep 3, 2018) Fix test failures
- dcf3f07 (Sep 4, 2018) Fix generated code compilation failures
- 97305f5 (Sep 5, 2018) Fix failing ML tests
- 0377e28 (Sep 28, 2018) Merge branch 'master' into udf_nullability
- fac6f1e (Sep 28, 2018) Fix compilation failure when running Pandas tests
- eebc18f (Jan 12, 2019) Rebase and add udf tests to proper modules
- e1d68e8 (Jul 31, 2019) Fix merge conflicts
- be5735a (Jul 31, 2019) Fix compilation errors
- 9cfc5b6 (Sep 7, 2019) Resolve merge conflict
- b074eb1 (Jan 18, 2020) Fix merge conflicts
- 516a708 (Jan 19, 2020) Fix test failures
8 changes: 8 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2829,6 +2829,14 @@ def udf(f=None, returnType=StringType()):
>>> import random
>>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()

.. note:: User-defined functions are assumed to be able to return null values by default.
If your function never returns null, call `asNonNullable` on the user-defined function.
E.g.:

>>> from pyspark.sql.types import StringType
>>> import getpass
>>> getuser_udf = udf(lambda: getpass.getuser(), StringType()).asNonNullable()

.. note:: User-defined functions do not support conditional expressions or short-circuit
evaluation in boolean expressions; all branches end up being evaluated internally. If a
function can fail on particular rows, the workaround is to incorporate the condition
into the function itself.
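For orientation, here is a minimal end-to-end sketch of the behavior the docstring above describes. It assumes this patch is applied and an active SparkSession named `spark`; the column alias and UDF names are illustrative only.

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# A non-nullable UDF marks its output column as non-nullable in the schema.
plus_four = udf(lambda x: x + 4, IntegerType()).asNonNullable()
df = spark.range(3).select(plus_four("id").alias("plus_four"))
print(df.schema["plus_four"].nullable)  # False

# A non-nullable UDF that nevertheless returns None fails at execution time.
bad = udf(lambda x: None, IntegerType()).asNonNullable()
# spark.range(3).select(bad("id")).collect()  # raises: Cannot return null value ...
```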
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests/test_pandas_udf.py
@@ -65,6 +65,11 @@ def test_pandas_udf_basic(self):
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, DoubleType()).asNonNullable()
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
self.assertFalse(udf.nullable)

def test_pandas_udf_decorator(self):
@pandas_udf(DoubleType())
def foo(x):
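A hedged usage sketch based on the test above: the same `asNonNullable()` toggle applies to pandas UDFs. It assumes an active SparkSession named `spark` with Arrow support available; the names are illustrative.

```python
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

# `s` arrives as a pandas Series; the result column is declared non-nullable.
plus_one = pandas_udf(lambda s: s + 1.0, DoubleType()).asNonNullable()
df = spark.range(5).select(plus_one("id").alias("v"))
print(df.schema["v"].nullable)  # False under this patch
```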
36 changes: 36 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
@@ -372,6 +372,42 @@ def test_non_existed_udaf(self):
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

def test_udf_no_nulls(self):
from pyspark.sql.functions import udf
plus_four = udf(lambda x: x + 4, IntegerType()).asNonNullable()
df = self.spark.range(10)
res = df.select(plus_four(df['id']).alias('plus_four'))
self.assertFalse(plus_four.nullable)
self.assertFalse(res.schema['plus_four'].nullable)
self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

def test_udf_with_callable_no_nulls(self):
df = self.spark.range(10)

class PlusFour:
def __call__(self, col):
if col is not None:
return col + 4
else:
return 0

call = PlusFour()
pudf = UserDefinedFunction(call, LongType()).asNonNullable()
res = df.select(pudf(df['id']).alias('plus_four'))
self.assertFalse(pudf.nullable)
self.assertFalse(res.schema['plus_four'].nullable)
self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

def test_udf_no_nulls_returns_null(self):
from pyspark.sql.functions import udf
plus_four = udf(lambda x: x + 4 if x > 0 else None, IntegerType()).asNonNullable()
df = self.spark.range(10)
res = df.select(plus_four(df['id']).alias('plus_four'))
self.assertFalse(plus_four.nullable)
self.assertFalse(res.schema['plus_four'].nullable)
with self.assertRaisesRegexp(Exception, "Cannot return null value.*"):
res.agg({'plus_four': 'sum'}).collect()

def test_udf_with_input_file_name(self):
from pyspark.sql.functions import input_file_name
sourceFile = udf(lambda path: path, StringType())
25 changes: 20 additions & 5 deletions python/pyspark/sql/udf.py
@@ -75,7 +75,7 @@ def _create_udf(f, returnType, evalType):

# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(
f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
f, returnType=returnType, name=None, evalType=evalType, deterministic=True, nullable=True)
return udf_obj._wrapped()


@@ -93,7 +93,8 @@ def __init__(self, func,
returnType=StringType(),
name=None,
evalType=PythonEvalType.SQL_BATCHED_UDF,
deterministic=True):
deterministic=True,
nullable=True):
if not callable(func):
raise TypeError(
"Invalid function: not a function or callable (__call__ is not defined): "
@@ -118,6 +119,7 @@ def __init__(self, func,
else func.__class__.__name__)
self.evalType = evalType
self.deterministic = deterministic
self.nullable = nullable

@property
def returnType(self):
@@ -202,7 +204,7 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
self._name, wrapped_func, jdt, self.evalType, self.deterministic, self.nullable)
return judf

def __call__(self, *cols):
@@ -240,11 +242,14 @@ def wrapper(*args):
wrapper.deterministic = self.deterministic
wrapper.asNondeterministic = functools.wraps(
self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())
wrapper.asNonNullable = functools.wraps(
self.asNonNullable)(lambda: self.asNonNullable()._wrapped())
wrapper.nullable = self.nullable
return wrapper

def asNondeterministic(self):
"""
Updates UserDefinedFunction to nondeterministic.
Updates :class:`UserDefinedFunction` to nondeterministic.

.. versionadded:: 2.3
"""
@@ -254,6 +259,15 @@ def asNondeterministic(self):
self.deterministic = False
return self

def asNonNullable(self):
"""
Updates :class:`UserDefinedFunction` to non-nullable.

.. versionadded:: 2.4
[Inline review comment, Member]: Let's do this 2.3.
[Inline review comment, Member]: 3.0?
"""
self.nullable = False
return self


class UDFRegistration(object):
"""
@@ -371,7 +385,8 @@ def register(self, name, f, returnType=None):
"SQL_MAP_PANDAS_ITER_UDF.")
register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
evalType=f.evalType,
deterministic=f.deterministic)
deterministic=f.deterministic,
nullable=f.nullable)
return_udf = f
else:
if returnType is None:
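The wrapper plumbing above follows a fluent-flag pattern. Below is a self-contained sketch (not Spark's actual classes; the names are hypothetical) of that pattern, showing why `asNonNullable()` returns the object so calls chain, and why `_wrapped()` re-exposes the flag and method on the plain function.

```python
import functools

class SketchUDF:
    """Illustrative stand-in for UserDefinedFunction."""

    def __init__(self, func, nullable=True):
        self.func = func
        self.nullable = nullable

    def asNonNullable(self):
        self.nullable = False
        return self  # returning self lets calls chain fluently

    def _wrapped(self):
        @functools.wraps(self.func)
        def wrapper(*args):
            return self.func(*args)
        # Mirror the flag and the fluent method on the plain function,
        # as udf.py does above for nullable/asNonNullable.
        wrapper.nullable = self.nullable
        wrapper.asNonNullable = lambda: self.asNonNullable()._wrapped()
        return wrapper

plain = SketchUDF(lambda x: x + 4)._wrapped()
strict = plain.asNonNullable()
print(plain.nullable, strict.nullable)  # True False
```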
PythonUDF.scala
@@ -56,6 +56,7 @@ case class PythonUDF(
children: Seq[Expression],
evalType: Int,
udfDeterministic: Boolean,
udfNullable: Boolean = true,
resultId: ExprId = NamedExpression.newExprId)
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {

@@ -66,7 +67,7 @@ case class PythonUDF(
lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)(
exprId = resultId)

override def nullable: Boolean = true
override def nullable: Boolean = udfNullable

override lazy val canonicalized: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
ScalaUDF.scala
@@ -992,6 +992,8 @@ case class ScalaUDF(
}.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType)
val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]")
val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage)
val errorMsgNonNullableTerm = ctx.addReferenceObj("errNonNullableMsg",
udfNonNullableErrorMessage)
val resultTerm = ctx.freshName("result")

// codegen for children expressions
@@ -1039,27 +1041,40 @@ case class ScalaUDF(
|}
""".stripMargin

val canBeNull = ctx.freshName("canBeNull")
val initNullable = s"boolean $canBeNull = $nullable;"

ev.copy(code =
code"""
|$evalCode
|${initArgs.mkString("\n")}
|$callFunc
|
|boolean ${ev.isNull} = $resultTerm == null;
|$initNullable
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|} else if (!$canBeNull) {
| throw new RuntimeException($errorMsgNonNullableTerm);
|}
""".stripMargin)
}

private[this] val resultConverter = CatalystTypeConverters.createToCatalystConverter(dataType)

lazy val udfErrorMessage = {
lazy val functionSignature = {
val funcCls = function.getClass.getSimpleName
val inputTypes = children.map(_.dataType.catalogString).mkString(", ")
val outputType = dataType.catalogString
s"Failed to execute user defined function($funcCls: ($inputTypes) => $outputType)"
val inputTypes = children.map(_.dataType.simpleString).mkString(", ")
s"user defined function ($funcCls: ($inputTypes) => ${dataType.simpleString})"
}

lazy val udfErrorMessage = {
s"Failed to execute $functionSignature"
}

lazy val udfNonNullableErrorMessage = {
s"Cannot return null value from $functionSignature"
}

override def eval(input: InternalRow): Any = {
@@ -1070,6 +1085,10 @@ case class ScalaUDF(
throw new SparkException(udfErrorMessage, e)
}

if (result == null && !nullable) {
throw new RuntimeException(udfNonNullableErrorMessage)
}

resultConverter(result)
}
}
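For reference, a small Python sketch of how the refactored error messages above compose from the shared `functionSignature`; the class and type names here are made up for illustration.

```python
def function_signature(func_cls, input_types, output_type):
    # Mirrors: s"user defined function ($funcCls: ($inputTypes) => $outputType)"
    return "user defined function (%s: (%s) => %s)" % (
        func_cls, ", ".join(input_types), output_type)

sig = function_signature("$anonfun$1", ["bigint"], "int")
print("Failed to execute " + sig)                 # udfErrorMessage
print("Cannot return null value from " + sig)     # udfNonNullableErrorMessage
```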
BatchEvalPythonExec.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -34,6 +34,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, resultAttrs, child) {


protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
argOffsets: Array[Array[Int]],
@@ -83,13 +84,30 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
var row: InternalRow = null
if (udfs.length == 1) {
// fast path for single UDF
mutableRow(0) = fromJava(result)
mutableRow
row = mutableRow
} else {
fromJava(result).asInstanceOf[InternalRow]
row = fromJava(result).asInstanceOf[InternalRow]
}

verifyResults(row)

row
}
}

private def verifyResults(row: InternalRow) {
for ((udf, i) <- udfs.view.zipWithIndex) {
if (row.isNullAt(i) && !udf.nullable) {
val inputTypes = udf.children.map(_.dataType.simpleString).mkString(", ")
val signature = s"${udf.name}: ($inputTypes) => ${udf.dataType.simpleString}"
throw new SparkException(
s"Cannot return null value from user defined function $signature.")
}
}
}

}
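A rough Python rendering of the `verifyResults` check above (an editorial sketch, not Spark source): after a batch of Python UDFs fills a row, a null in any column whose UDF was declared non-nullable is an error.

```python
def verify_results(row, udfs):
    # row[i] holds the result of udfs[i]; each entry is (nullable, signature).
    for i, (nullable, signature) in enumerate(udfs):
        if row[i] is None and not nullable:
            raise RuntimeError(
                "Cannot return null value from user defined function %s." % signature)

# Passes: only the nullable UDF produced a null.
verify_results([5, None],
               [(False, "plus_four: (bigint) => int"),
                (True, "maybe_null: (bigint) => string")])
# verify_results([None, 7], ...)  # would raise for the non-nullable UDF
```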
UserDefinedPythonFunction.scala
@@ -30,10 +30,11 @@ case class UserDefinedPythonFunction(
func: PythonFunction,
dataType: DataType,
pythonEvalType: Int,
udfDeterministic: Boolean) {
udfDeterministic: Boolean,
udfNullable: Boolean = true) {

def builder(e: Seq[Expression]): Expression = {
PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic)
PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic, udfNullable)
}

/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
35 changes: 35 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

import java.math.BigDecimal

import org.apache.spark.SparkException
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.plans.logical.Project
@@ -292,6 +293,40 @@ class UDFSuite extends QueryTest with SharedSparkSession {
""".stripMargin).toDF(), complexData.select("m", "a", "b"))
}

test("Non-nullable UDF") {
val foo = udf(() => Math.random())
spark.udf.register("random0", foo.asNonNullable())
val df = sql("SELECT random0()")
assert(df.logicalPlan.asInstanceOf[Project].projectList.forall(!_.nullable))
assert(df.head().getDouble(0) >= 0.0)

val foo1 = foo.asNonNullable()
val df1 = testData.select(foo1())
assert(df1.logicalPlan.asInstanceOf[Project].projectList.forall(!_.nullable))
assert(df1.head().getDouble(0) >= 0.0)

val bar = udf(() => Math.random(), DataTypes.DoubleType).asNonNullable()
val df2 = testData.select(bar())
assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.nullable))
assert(df2.head().getDouble(0) >= 0.0)

val javaUdf = udf(new UDF0[Double] {
override def call(): Double = Math.random()
}, DoubleType).asNonNullable()
val df3 = testData.select(javaUdf())
assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.nullable))
assert(df3.head().getDouble(0) >= 0.0)
}

test("Non-nullable UDF returning null") {
val foo = udf(() => null).asNonNullable()
val df1 = testData.select(foo())
val e = intercept[SparkException] {
df1.head()
}
assert(e.getMessage.contains("Cannot return null value from user defined function"))
}

test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") {
val myUDF = spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) })
