Skip to content

Commit

Permalink
[SPARK-33884][SQL] Simplify CaseWhenclauses with (true and false) and…
Browse files Browse the repository at this point in the history
… (false and true)

### What changes were proposed in this pull request?

This pr simplify `CaseWhen`clauses with (true and false) and (false and true):

Expression | cond.nullable | After simplify
-- | -- | --
case when cond then true else false end | true | cond <=> true
case when cond then true else false end | false | cond
case when cond then false else true end | true | !(cond <=> true)
case when cond then false else true end | false | !cond

### Why are the changes needed?

Improve query performance.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #30898 from wangyum/SPARK-33884.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
wangyum authored and cloud-fan committed Dec 29, 2020
1 parent 379afcd commit f7bdea3
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,11 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l)
case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l)

case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) =>
if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond
case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) =>
if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond)

case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class PushFoldableIntoBranchesSuite
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(Literal(2)))
assert(!nonDeterministic.deterministic)
assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(TrueLiteral)))
GreaterThanOrEqual(Rand(1), Literal(0.5)))
assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral)))

Expand Down Expand Up @@ -269,4 +269,13 @@ class PushFoldableIntoBranchesSuite
Literal.create(null, BooleanType))
}
}

test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
assertEquivalent(
EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)),
'a > 10 <=> TrueLiteral)
assertEquivalent(
EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)),
Not('a > 10 <=> TrueLiteral))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -243,4 +243,40 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
Literal.create(null, IntegerType))
}
}

test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
// verify the boolean equivalence of all transformations involved
val fields = Seq(
'cond.boolean.notNull,
'cond_nullable.boolean,
'a.boolean,
'b.boolean
)
val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) }

val exprs = Seq(
// actual expressions of the transformations: original -> transformed
CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral) -> cond,
CaseWhen(Seq((cond, FalseLiteral)), TrueLiteral) -> !cond,
CaseWhen(Seq((cond_nullable, TrueLiteral)), FalseLiteral) -> (cond_nullable <=> true),
CaseWhen(Seq((cond_nullable, FalseLiteral)), TrueLiteral) -> (!(cond_nullable <=> true)))

// check plans
for ((originalExpr, expectedExpr) <- exprs) {
assertEquivalent(originalExpr, expectedExpr)
}

// check evaluation
val binaryBooleanValues = Seq(true, false)
val ternaryBooleanValues = Seq(true, false, null)
for (condVal <- binaryBooleanValues;
condNullableVal <- ternaryBooleanValues;
aVal <- ternaryBooleanValues;
bVal <- ternaryBooleanValues;
(originalExpr, expectedExpr) <- exprs) {
val inputRow = create_row(condVal, condNullableVal, aVal, bVal)
val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
checkEvaluation(originalExpr, optimizedVal, inputRow)
}
}
}

0 comments on commit f7bdea3

Please sign in to comment.