Skip to content

Commit

Permalink
[SPARK-12258] [SQL] passing null into ScalaUDF (follow-up)
Browse files: browse the repository at this point in the history
This is a follow-up PR for #10259.

Author: Davies Liu <davies@databricks.com>

Closes #10266 from davies/null_udf2.
  • Loading branch information
Davies Liu authored and davies committed Dec 11, 2015
1 parent 518ab51 commit c119a34
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1029,24 +1029,27 @@ case class ScalaUDF(
// such as IntegerType, its javaType is `int` and the returned type of user-defined
// function is Object. Trying to convert an Object to `int` will cause a casting exception.
val evalCode = evals.map(_.code).mkString
val funcArguments = converterTerms.zipWithIndex.map {
case (converter, i) =>
val eval = evals(i)
val dt = children(i).dataType
s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})"
}.mkString(",")
val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " +
s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" +
s".apply($funcTerm.apply($funcArguments));"
val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
val eval = evals(i)
val argTerm = ctx.freshName("arg")
val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
(convert, argTerm)
}.unzip

evalCode + s"""
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
Boolean ${ev.isNull};
val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"

s"""
$evalCode
${converters.mkString("\n")}
$callFunc

${ev.value} = $resultTerm;
${ev.isNull} = $resultTerm == null;
boolean ${ev.isNull} = $resultTerm == null;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $resultTerm;
}
"""
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1144,9 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {

// passing null into the UDF that could handle it
val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
(i: java.lang.Integer) => if (i == null) -10 else i * 2
(i: java.lang.Integer) => if (i == null) -10 else null
}
checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil)
checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)

sqlContext.udf.register("boxedUDF",
(i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer)
checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil)

val primitiveUDF = udf((i: Int) => i * 2)
checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)
Expand Down

0 comments on commit c119a34

Please sign in to comment.