Skip to content

Commit

Permalink
[SPARK-18467][SQL] Extracts method for preparing arguments from Stati…
Browse files Browse the repository at this point in the history
…cInvoke, Invoke and NewInstance and modify to short circuit if arguments have null when `needNullCheck == true`.

## What changes were proposed in this pull request?

This pr extracts method for preparing arguments from `StaticInvoke`, `Invoke` and `NewInstance` and modify to short circuit if arguments have `null` when `propageteNull == true`.

The steps are as follows:

1. Introduce `InvokeLike` to extract common logic from `StaticInvoke`, `Invoke` and `NewInstance` to prepare arguments.
`StaticInvoke` and `Invoke` had a risk to exceed 64kb JVM limit to prepare arguments but after this patch they can handle them because they share the preparing code of NewInstance, which handles the limit well.

2. Remove unneeded null checking and fix nullability of `NewInstance`.
Avoid some of nullabilty checking which are not needed because the expression is not nullable.

3. Modify to short circuit if arguments have `null` when `needNullCheck == true`.
If `needNullCheck == true`, preparing arguments can be skipped if we found one of them is `null`, so modified to short circuit in the case.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #15901 from ueshin/issues/SPARK-18467.
  • Loading branch information
ueshin authored and cloud-fan committed Nov 21, 2016
1 parent b625a36 commit 6585479
Showing 1 changed file with 101 additions and 62 deletions.
Expand Up @@ -32,6 +32,78 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types._

/**
* Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]].
*/
trait InvokeLike extends Expression with NonSQLExpression {

def arguments: Seq[Expression]

def propagateNull: Boolean

protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable)

/**
* Prepares codes for arguments.
*
* - generate codes for argument.
* - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments.
* - avoid some of nullabilty checking which are not needed because the expression is not
* nullable.
* - when needNullCheck == true, short circuit if we found one of arguments is null because
* preparing rest of arguments can be skipped in the case.
*
* @param ctx a [[CodegenContext]]
* @return (code to prepare arguments, argument string, result of argument null check)
*/
def prepareArguments(ctx: CodegenContext): (String, String, String) = {

val resultIsNull = if (needNullCheck) {
val resultIsNull = ctx.freshName("resultIsNull")
ctx.addMutableState("boolean", resultIsNull, "")
resultIsNull
} else {
"false"
}
val argValues = arguments.map { e =>
val argValue = ctx.freshName("argValue")
ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
argValue
}

val argCodes = if (needNullCheck) {
val reset = s"$resultIsNull = false;"
val argCodes = arguments.zipWithIndex.map { case (e, i) =>
val expr = e.genCode(ctx)
val updateResultIsNull = if (e.nullable) {
s"$resultIsNull = ${expr.isNull};"
} else {
""
}
s"""
if (!$resultIsNull) {
${expr.code}
$updateResultIsNull
${argValues(i)} = ${expr.value};
}
"""
}
reset +: argCodes
} else {
arguments.zipWithIndex.map { case (e, i) =>
val expr = e.genCode(ctx)
s"""
${expr.code}
${argValues(i)} = ${expr.value};
"""
}
}
val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)

(argCode, argValues.mkString(", "), resultIsNull)
}
}

/**
* Invokes a static function, returning the result. By default, any of the arguments being null
* will result in returning null instead of calling the function.
Expand All @@ -50,7 +122,7 @@ case class StaticInvoke(
dataType: DataType,
functionName: String,
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression with NonSQLExpression {
propagateNull: Boolean = true) extends InvokeLike {

val objectName = staticObject.getName.stripSuffix("$")

Expand All @@ -62,16 +134,10 @@ case class StaticInvoke(

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")

val callFunc = s"$objectName.$functionName($argString)"
val (argCode, argString, resultIsNull) = prepareArguments(ctx)

val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"boolean ${ev.isNull} = false;"
}
val callFunc = s"$objectName.$functionName($argString)"

// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
Expand All @@ -82,9 +148,9 @@ case class StaticInvoke(
}

val code = s"""
${argGen.map(_.code).mkString("\n")}
$setIsNull
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
$argCode
boolean ${ev.isNull} = $resultIsNull;
final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
$postNullCheck
"""
ev.copy(code = code)
Expand All @@ -103,13 +169,15 @@ case class StaticInvoke(
* @param functionName The name of the method to call.
* @param dataType The expected return type of the function.
* @param arguments An optional list of expressions, whos evaluation will be passed to the function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
*/
case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression with NonSQLExpression {
propagateNull: Boolean = true) extends InvokeLike {

override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject +: arguments
Expand All @@ -131,8 +199,8 @@ case class Invoke(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.genCode(ctx)
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")

val (argCode, argString, resultIsNull) = prepareArguments(ctx)

val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive
val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty
Expand Down Expand Up @@ -164,28 +232,26 @@ case class Invoke(
"""
}

val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"boolean ${ev.isNull} = ${obj.isNull};"
}

// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}

val code = s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
$setIsNull
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
$evaluate
if (!${obj.isNull}) {
$argCode
${ev.isNull} = $resultIsNull;
if (!${ev.isNull}) {
$evaluate
}
$postNullCheck
}
$postNullCheck
"""
ev.copy(code = code)
}
Expand Down Expand Up @@ -223,10 +289,10 @@ case class NewInstance(
arguments: Seq[Expression],
propagateNull: Boolean,
dataType: DataType,
outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression {
outerPointer: Option[() => AnyRef]) extends InvokeLike {
private val className = cls.getName

override def nullable: Boolean = propagateNull
override def nullable: Boolean = needNullCheck

override def children: Seq[Expression] = arguments

Expand All @@ -245,52 +311,25 @@ case class NewInstance(

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val argIsNulls = ctx.freshName("argIsNulls")
ctx.addMutableState("boolean[]", argIsNulls,
s"$argIsNulls = new boolean[${arguments.size}];")
val argValues = arguments.zipWithIndex.map { case (e, i) =>
val argValue = ctx.freshName("argValue")
ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
argValue
}

val argCodes = arguments.zipWithIndex.map { case (e, i) =>
val expr = e.genCode(ctx)
expr.code + s"""
$argIsNulls[$i] = ${expr.isNull};
${argValues(i)} = ${expr.value};
"""
}
val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
val (argCode, argString, resultIsNull) = prepareArguments(ctx)

val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))

var isNull = ev.isNull
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"""
boolean $isNull = false;
for (int idx = 0; idx < ${arguments.length}; idx++) {
if ($argIsNulls[idx]) { $isNull = true; break; }
}
"""
} else {
isNull = "false"
""
}
ev.isNull = resultIsNull

val constructorCall = outer.map { gen =>
s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})"""
s"${gen.value}.new ${cls.getSimpleName}($argString)"
}.getOrElse {
s"new $className(${argValues.mkString(", ")})"
s"new $className($argString)"
}

val code = s"""
$argCode
${outer.map(_.code).getOrElse("")}
$setIsNull
final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
"""
ev.copy(code = code, isNull = isNull)
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
"""
ev.copy(code = code)
}

override def toString: String = s"newInstance($cls)"
Expand Down

0 comments on commit 6585479

Please sign in to comment.