Skip to content

Commit

Permalink
adding test case
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Nov 24, 2017
1 parent 98eaae9 commit 6225c8e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
Expand Up @@ -261,12 +261,14 @@ case class CaseWhen(
${ev.value} = ${res.value};
}
"""
}.getOrElse("")
}

val allConditions = cases ++ elseCode

val casesCode = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
cases.mkString("\n")
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
allConditions.mkString("\n")
} else {
ctx.splitExpressions(cases, "caseWhen",
ctx.splitExpressions(allConditions, "caseWhen",
("InternalRow", ctx.INPUT_ROW) :: ("boolean", conditionMet) :: Nil, returnType = "boolean",
makeSplitFunction = {
func =>
Expand All @@ -284,8 +286,7 @@ case class CaseWhen(
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
boolean $conditionMet = false;
$casesCode
$elseCode""")
$code""")
}
}

Expand Down
Expand Up @@ -29,7 +29,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -2126,4 +2126,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val mean = result.select("DecimalCol").where($"summary" === "mean")
assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
}

test("SPARK-22520: support code generation for large CaseWhen") {
val N = 30
var expr1 = when($"id" === lit(0), 0)
var expr2 = when($"id" === lit(0), 10)
(1 to N).foreach { i =>
expr1 = expr1.when($"id" === lit(i), -i)
expr2 = expr2.when($"id" === lit(i + 10), i)
}
val df = spark.range(1).select(expr1, expr2.otherwise(0))
df.show
assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
}
}

0 comments on commit 6225c8e

Please sign in to comment.