[SPARK-11725][SQL] correctly handle null inputs for UDF #9770

Closed · wants to merge 3 commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -719,6 +719,15 @@ trait ScalaReflection {
}
}

  /**
   * Returns the classes of the input parameters of a Scala function object.
   */
  def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
    val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
    assert(methods.length == 1)
    methods.head.getParameterTypes
  }

Review comment on `getParameterTypes` (Contributor): Scaladoc please

def typeOfObject: PartialFunction[Any, DataType] = {
// The data type can be determined without ambiguity.
case obj: Boolean => BooleanType
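An aside on the `!m.isBridge` filter above: for an anonymous function with primitive parameters, scalac emits both a type-specific `apply` and a synthetic bridge `apply` that takes and returns `Object`; only the non-bridge overload carries the primitive parameter classes. A hypothetical REPL sketch (not part of the patch; generated signatures vary across Scala versions) to see both:

```scala
// Hypothetical sketch: list the `apply` overloads the compiler generates
// for an anonymous function object.
val f = (i: Int, j: Long) => "x"
f.getClass.getMethods
  .filter(_.getName == "apply")
  .foreach { m =>
    println(s"bridge=${m.isBridge} params=[${m.getParameterTypes.mkString(", ")}]")
  }
// Expected shape of the output (exact order may vary):
//   bridge=false params=[int, long]
//   bridge=true  params=[class java.lang.Object, class java.lang.Object]
```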
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
+import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.types._

@@ -85,6 +85,8 @@ class Analyzer(
        extendedResolutionRules : _*),
      Batch("Nondeterministic", Once,
        PullOutNondeterministic),
      Batch("UDF", Once,
        HandleNullInputsForUDF),
      Batch("Cleanup", fixedPoint,
        CleanupAliases)
    )
@@ -1063,6 +1065,34 @@
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}
}

  /**
   * Correctly handles null primitive inputs for a UDF by adding an extra [[If]] expression to do
   * the null check. When a user defines a UDF with primitive parameters, there is no way to tell
   * if a primitive parameter is null or not, so here we assume the primitive input is
   * null-propagating and return null if the input is null.
   */
  object HandleNullInputsForUDF extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p if !p.resolved => p // Skip unresolved nodes.

      case plan => plan transformExpressionsUp {

        case udf @ ScalaUDF(func, _, inputs, _) =>
          val parameterTypes = ScalaReflection.getParameterTypes(func)
          assert(parameterTypes.length == inputs.length)

          val inputsNullCheck = parameterTypes.zip(inputs)
            // TODO: skip null handling for not-nullable primitive inputs after we can completely
            // trust the `nullable` information.
            // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
            .filter { case (cls, _) => cls.isPrimitive }
            .map { case (_, expr) => IsNull(expr) }
            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
          inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
      }
    }
  }
}

Review thread on the TODO above:

Contributor (author): After an offline discussion with @davies, we think it's dangerous to completely trust the nullable information and optimize based on it, especially for the 1.6 release. Maybe we can do it after 1.6.

cc @marmbrus

Contributor: Given that most of the common code paths (for example, generated expressions and joins) do not use nullable, there could be corner cases where nullable is not generated correctly (for some data sources). I think it's risky for 1.6. I'd vote to do it in the next release (and consider nullable in most places).

Contributor: To play devil's advocate: when the info is wrong, it is usually too conservative (allowing nulls where there are none). Also, I'm not really sure what will change between now and 1.7 (i.e. if there are bugs, we need to find them eventually). That said, I'm fine with waiting, but we should use this info eventually, given the amount of effort we spend passing it around.

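For reference, the shape of the rewrite the rule performs: every input bound to a primitive parameter contributes an `IsNull` check, the checks are OR-ed together, and the UDF call is wrapped in an `If` that returns null. A minimal self-contained sketch with hypothetical stand-in types (not Catalyst's real expression classes):

```scala
// Minimal sketch of the rewrite, using hypothetical stand-in types
// (NOT Catalyst's real classes).
sealed trait Expr
case class Input(name: String) extends Expr
case class IsNull(child: Expr) extends Expr
case class Or(left: Expr, right: Expr) extends Expr
case class If(predicate: Expr, trueValue: Expr, falseValue: Expr) extends Expr
case object NullLiteral extends Expr
case class UdfCall(inputs: Seq[Expr]) extends Expr

def handleNullInputs(udf: UdfCall, parameterTypes: Seq[Class[_]]): Expr = {
  require(parameterTypes.length == udf.inputs.length)
  val inputsNullCheck = parameterTypes.zip(udf.inputs)
    .filter { case (cls, _) => cls.isPrimitive } // only primitive params need a guard
    .map { case (_, expr) => IsNull(expr): Expr }
    .reduceLeftOption[Expr]((e1, e2) => Or(e1, e2))
  inputsNullCheck.map(If(_, NullLiteral, udf)).getOrElse(udf)
}

// (s: String, d: Double) => "x" applied to columns a and b:
//   handleNullInputs(UdfCall(Seq(Input("a"), Input("b"))),
//     Seq(classOf[String], classOf[Double]))
// returns If(IsNull(Input("b")), NullLiteral, UdfCall(...)) -- only the
// Double input is guarded, matching the udf2 case in AnalysisSuite below.
```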
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType

/**
 * User-defined function.
 * @param function The user-defined Scala function to run.
 *                 Note that if you use primitive parameters, you are not able to check whether
 *                 an input is null or not, and the UDF will return null for you if a primitive
 *                 input is null. Use a boxed type or [[Option]] if you want to do the
 *                 null handling yourself.
 * @param dataType Return type of the function.
 * @param children The input expressions of this UDF.
 * @param inputTypes The expected input types of this UDF.
 */
case class ScalaUDF(
    function: AnyRef,
    ...
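A brief usage sketch of the behavior described in the scaladoc above (it mirrors the DataFrameSuite test at the end of this PR; `df` is assumed to be a DataFrame with a nullable integer column "age", with the usual `functions.udf` and implicits imports in scope):

```scala
import org.apache.spark.sql.functions.udf

// Primitive parameter: the analyzer wraps the call in a null check,
// so a null "age" yields a null result without calling the function.
val primitiveUDF = udf((i: Int) => i * 2)

// Boxed parameter: the function sees the null and handles it itself.
val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { i =>
  if (i == null) null else i * 2
}

df.select(primitiveUDF($"age"), boxedUDF($"age"))
```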
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType))
}
}

test("get parameter type from a function object") {
val primitiveFunc = (i: Int, j: Long) => "x"
val primitiveTypes = getParameterTypes(primitiveFunc)
assert(primitiveTypes.forall(_.isPrimitive))
assert(primitiveTypes === Seq(classOf[Int], classOf[Long]))

val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x"
val boxedTypes = getParameterTypes(boxedFunc)
assert(boxedTypes.forall(!_.isPrimitive))
assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long]))

val anyFunc = (i: Any, j: AnyRef) => "x"
val anyTypes = getParameterTypes(anyFunc)
assert(anyTypes.forall(!_.isPrimitive))
assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object]))
}
}
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -174,4 +174,48 @@ class AnalysisSuite extends AnalysisTest {
)
assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
}

test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
val string = testRelation2.output(0)
val double = testRelation2.output(2)
val short = testRelation2.output(4)
val nullResult = Literal.create(null, StringType)

def checkUDF(udf: Expression, transformed: Expression): Unit = {
checkAnalysis(
Project(Alias(udf, "")() :: Nil, testRelation2),
Project(Alias(transformed, "")() :: Nil, testRelation2)
)
}

// non-primitive parameters do not need special null handling
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
val expected1 = udf1
checkUDF(udf1, expected1)

// only primitive parameter needs special null handling
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
val expected2 = If(IsNull(double), nullResult, udf2)
checkUDF(udf2, expected2)

// special null handling should apply to all primitive parameters
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
val expected3 = If(
IsNull(short) || IsNull(double),
nullResult,
udf3)
checkUDF(udf3, expected3)

// we can skip special null handling for primitive parameters that are not nullable
// TODO: this is disabled for now as we can not completely trust `nullable`.
val udf4 = ScalaUDF(
(s: Short, d: Double) => "x",
StringType,
short :: double.withNullability(false) :: Nil)
val expected4 = If(
IsNull(short),
nullResult,
udf4)
// checkUDF(udf4, expected4)
}
}
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (14 additions, 0 deletions)
@@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.select(df("*")), Row(1, "a"))
checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a"))
}

test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
val df = Seq(
new java.lang.Integer(22) -> "John",
null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name")

val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
(i: java.lang.Integer) => if (i == null) null else i * 2
}
checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil)

val primitiveUDF = udf((i: Int) => i * 2)
checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)
}
}