Skip to content

Commit

Permalink
[SPARK-13242] [SQL] codegen fallback in case-when if there many branches
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

If there are many branches in a CaseWhen expression, the generated code could go above the 64K limit for single java method, will fail to compile. This PR change it to fallback to interpret mode if there are more than 20 branches.

This PR is based on #11243 and #11221, thanks to joehalliwell

Closes #11243
Closes #11221

## How was this patch tested?

Add a test with 50 branches.

Author: Davies Liu <davies@databricks.com>

Closes #11592 from davies/fix_when.
  • Loading branch information
Davies Liu authored and davies committed Mar 9, 2016
1 parent 53ba6d6 commit 9634e17
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
* @param elseValue optional value for the else branch
*/
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
extends Expression {
extends Expression with CodegenFallback {

override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

Expand Down Expand Up @@ -136,7 +136,16 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
}
}

def shouldCodegen: Boolean = {
branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN
}

override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
if (!shouldCodegen) {
// Fallback to interpreted mode if there are too many branches, as it may reach the
// 64K limit (limit on bytecode size for a single function).
return super[CodegenFallback].genCode(ctx, ev)
}
// Generate code that looks like:
//
// condA = ...
Expand Down Expand Up @@ -205,6 +214,9 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
/** Factory methods for CaseWhen. */
object CaseWhen {

// The maxium number of switches supported with codegen.
val MAX_NUM_CASES_FOR_CODEGEN = 20

def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
CaseWhen(branches, Option(elseValue))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ case class Literal protected (value: Any, dataType: DataType)
case FloatType =>
val v = value.asInstanceOf[Float]
if (v.isNaN || v.isInfinite) {
super.genCode(ctx, ev)
super[CodegenFallback].genCode(ctx, ev)
} else {
ev.isNull = "false"
ev.value = s"${value}f"
Expand All @@ -212,7 +212,7 @@ case class Literal protected (value: Any, dataType: DataType)
case DoubleType =>
val v = value.asInstanceOf[Double]
if (v.isNaN || v.isInfinite) {
super.genCode(ctx, ev)
super[CodegenFallback].genCode(ctx, ev)
} else {
ev.isNull = "false"
ev.value = s"${value}D"
Expand All @@ -232,7 +232,7 @@ case class Literal protected (value: Any, dataType: DataType)
""
// eval() version may be faster for non-primitive types
case other =>
super.genCode(ctx, ev)
super[CodegenFallback].genCode(ctx, ev)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("SPARK-13242: case-when expression with large number of branches (or cases)") {
val cases = 50
val clauses = 20

// Generate an individual case
def generateCase(n: Int): (Expression, Expression) = {
val condition = (1 to clauses)
.map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n")))
.reduceLeft[Expression]((l, r) => Or(l, r))
(condition, Literal(n))
}

val expression = CaseWhen((1 to cases).map(generateCase(_)))

val plan = GenerateMutableProjection.generate(Seq(expression))()
val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
val actual = plan(input).toSeq(Seq(expression.dataType))

assert(actual(0) == cases)
}

test("test generated safe and unsafe projection") {
val schema = new StructType(Array(
StructField("a", StringType, true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru

private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
case e: CaseWhen => e.shouldCodegen
// CodegenFallback requires the input to be an InternalRow
case e: CodegenFallback => false
case _ => true
Expand Down

0 comments on commit 9634e17

Please sign in to comment.