Skip to content

Commit

Permalink
Addressed feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
dbtsai committed Jul 24, 2018
1 parent a9c97ce commit 59fada7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
val (h, t) = branches.span(_._1 != TrueLiteral)
CaseWhen( h :+ t.head, None)

case CaseWhen((cond, branchValue) :: Nil, elseValue) =>
If(cond, branchValue, elseValue.getOrElse(Literal(null, branchValue.dataType)))
case CaseWhen(branches, elseValue) if branches.length == 1 =>
// Using pattern matching like `CaseWhen((cond, branchValue) :: Nil, elseValue)` will not
// work since the implementation of `branches` can be `ArrayBuffer`. A full test is in
// "SPARK-24892: simplify `CaseWhen` to `If` when there is only one branch",
// `SQLQuerySuite.scala`.
val cond = branches.head._1
val trueValue = branches.head._2
val falseValue = elseValue.getOrElse(Literal(null, trueValue.dataType))
If(cond, trueValue, falseValue)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
}

private val trueBranch = (TrueLiteral, Literal(5))
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
private val normalBranch1 = (NonFoldableLiteral(true), Literal(10))
private val normalBranch2 = (NonFoldableLiteral(false), Literal(3))
private val unreachableBranch = (FalseLiteral, Literal(20))
private val nullBranch = (Literal.create(null, NullType), Literal(30))

Expand All @@ -60,18 +61,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
If(normalBranch._1, normalBranch._2, Literal(null, normalBranch._2.dataType)))
CaseWhen(unreachableBranch :: normalBranch1 :: unreachableBranch ::
normalBranch2 :: nullBranch :: Nil, None),
CaseWhen(normalBranch1 :: normalBranch2 :: Nil, None))
}

test("simplify CaseWhen to If when there is only one branch") {
assertEquivalent(
CaseWhen(normalBranch :: Nil, None),
If(normalBranch._1, normalBranch._2, Literal(null, normalBranch._2.dataType)))
CaseWhen(normalBranch1 :: Nil, Some(Literal(30))),
If(normalBranch1._1, normalBranch1._2, Literal(30)))

assertEquivalent(
CaseWhen(normalBranch :: Nil, Some(Literal(30))),
If(normalBranch._1, normalBranch._2, Literal(30)))
CaseWhen(normalBranch1 :: Nil, None),
If(normalBranch1._1, normalBranch1._2, Literal(null, normalBranch1._2.dataType)))

assertEquivalent(
CaseWhen(unreachableBranch :: normalBranch1 :: unreachableBranch :: nullBranch :: Nil, None),
If(normalBranch1._1, normalBranch1._2, Literal(null, normalBranch1._2.dataType)))
}

test("remove entire CaseWhen if only the else branch is reachable") {
Expand All @@ -86,28 +92,28 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

test("remove entire CaseWhen if the first branch is always true") {
assertEquivalent(
CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
CaseWhen(trueBranch :: normalBranch1 :: nullBranch :: Nil, None),
Literal(5))

// Test branch elimination and simplification in combination
assertEquivalent(
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch1
:: Nil, None),
Literal(5))

// Make sure this doesn't trigger if there is a non-foldable branch before the true branch
assertEquivalent(
CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
CaseWhen(normalBranch :: trueBranch :: Nil, None))
CaseWhen(normalBranch1 :: trueBranch :: normalBranch1 :: Nil, None),
CaseWhen(normalBranch1 :: trueBranch :: Nil, None))
}

test("simplify CaseWhen, prune branches following a definite true") {
assertEquivalent(
CaseWhen(normalBranch :: unreachableBranch ::
CaseWhen(normalBranch1 :: unreachableBranch ::
unreachableBranch :: nullBranch ::
trueBranch :: normalBranch ::
trueBranch :: normalBranch1 ::
Nil,
None),
CaseWhen(normalBranch :: trueBranch :: Nil, None))
CaseWhen(normalBranch1 :: trueBranch :: Nil, None))
}
}
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2813,4 +2813,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(df, Seq(Row(3, 99, 1)))
}
}

test("SPARK-24892: simplify `CaseWhen` to `If` when there is only one branch") {
withTable("t") {
Seq(Some(1), null, Some(3)).toDF("a").write.saveAsTable("t")

val plan1 = sql("select case when a is null then 1 end col1 from t")
val plan2 = sql("select if(a is null, 1, null) col1 from t")

checkAnswer(plan1, Row(null) :: Row(1) :: Row(null) :: Nil)
comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan)
}
}
}

0 comments on commit 59fada7

Please sign in to comment.