diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c3e9fa33e63a6..5ceb36513f840 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -86,7 +86,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi * @param elseValue optional value for the else branch */ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression { + extends Expression with CodegenFallback { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue @@ -136,7 +136,16 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E } } + def shouldCodegen: Boolean = { + branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN + } + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + if (!shouldCodegen) { + // Fall back to interpreted mode if there are too many branches, as it may reach the + // 64K limit (limit on bytecode size for a single function). + return super[CodegenFallback].genCode(ctx, ev) + } // Generate code that looks like: // // condA = ... @@ -205,6 +214,9 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E /** Factory methods for CaseWhen. */ object CaseWhen { + // The maximum number of branches supported with codegen (exclusive upper bound). 
+ val MAX_NUM_CASES_FOR_CODEGEN = 20 + def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = { CaseWhen(branches, Option(elseValue)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 37bfe98d3ab24..a76517a89cc4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -203,7 +203,7 @@ case class Literal protected (value: Any, dataType: DataType) case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { - super.genCode(ctx, ev) + super[CodegenFallback].genCode(ctx, ev) } else { ev.isNull = "false" ev.value = s"${value}f" @@ -212,7 +212,7 @@ case class Literal protected (value: Any, dataType: DataType) case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { - super.genCode(ctx, ev) + super[CodegenFallback].genCode(ctx, ev) } else { ev.isNull = "false" ev.value = s"${value}D" @@ -232,7 +232,7 @@ case class Literal protected (value: Any, dataType: DataType) "" // eval() version may be faster for non-primitive types case other => - super.genCode(ctx, ev) + super[CodegenFallback].genCode(ctx, ev) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index b5413fbe2bbcc..260dfb3f42244 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -58,6 +58,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-13242: case-when expression with large number of branches 
(or cases)") { + val cases = 50 + val clauses = 20 + + // Generate an individual case + def generateCase(n: Int): (Expression, Expression) = { + val condition = (1 to clauses) + .map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n"))) + .reduceLeft[Expression]((l, r) => Or(l, r)) + (condition, Literal(n)) + } + + val expression = CaseWhen((1 to cases).map(generateCase(_))) + + val plan = GenerateMutableProjection.generate(Seq(expression))() + val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) + val actual = plan(input).toSeq(Seq(expression.dataType)) + + assert(actual(0) == cases) + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 45578d50bfc0d..dd831e60cbf5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -416,6 +416,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true + case e: CaseWhen => e.shouldCodegen // CodegenFallback requires the input to be an InternalRow case e: CodegenFallback => false case _ => true