From 4ac3d1efdec05d56f5051901ca209b67b23be28b Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 27 Jul 2018 16:47:02 -0700 Subject: [PATCH 1/3] prune casewhen branch --- .../sql/catalyst/optimizer/expressions.scala | 23 ++++++++ .../optimizer/SimplifyConditionalSuite.scala | 54 ++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4696699337c9d..003f5e65d5e75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -416,6 +416,29 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { // these branches can be pruned away val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) + + case e @ CaseWhen(branches, _) => + val newBranches = branches.foldLeft(List[(Expression, Expression)]()) { + case (newBranches, branch) => + if (newBranches.exists(_._1.semanticEquals(branch._1))) { + // If a condition in a branch is previously seen, this branch can be pruned. + // TODO: In fact, if a condition is a sub-condition of the previous one, + // TODO: it can be pruned. This is less strict and can be implemented + // TODO: by decomposing seen conditions. + newBranches + } else if (newBranches.nonEmpty && newBranches.last._2.semanticEquals(branch._2)) { + // If the outputs of two adjacent branches are the same, two branches can be combined. + newBranches.take(newBranches.length - 1) + .:+((Or(newBranches.last._1, branch._1), newBranches.last._2)) + } else { + newBranches.:+(branch) + } + } + if (newBranches.length < branches.length) { + e.copy(branches = newBranches) + } else { + e + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index e210874a55d87..a1c69012d3871 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -46,7 +46,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val unreachableBranch = (FalseLiteral, Literal(20)) private val nullBranch = (Literal.create(null, NullType), Literal(30)) - private val testRelation = LocalRelation('a.int) + val isNotNullCond = IsNotNull(UnresolvedAttribute("a")) + val isNullCond = IsNull(UnresolvedAttribute("a")) + val notCond = Not(UnresolvedAttribute("c")) test("simplify if") { assertEquivalent( @@ -122,4 +124,54 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { None), CaseWhen(normalBranch :: trueBranch :: Nil, None)) } + + test("remove a branch in CaseWhen if a cond in this branch is previously seen") { + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (GreaterThan(Rand(0), Literal(0.5)), Literal(2)) :: + (NonFoldableLiteral(true), Literal(3)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(4)) :: + (NonFoldableLiteral(true), Literal(5)) :: + (NonFoldableLiteral(false), Literal(6)) :: + (NonFoldableLiteral(false), Literal(7)) :: + Nil, + None), + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (GreaterThan(Rand(0), Literal(0.5)), Literal(2)) :: + (NonFoldableLiteral(true), Literal(3)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(4)) :: + (NonFoldableLiteral(false), Literal(6)) :: + Nil, + None) + ) + } + + test("combine two adjacent branches in CaseWhen if they have the same output values") { + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (NonFoldableLiteral(true), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(3)) :: + (NonFoldableLiteral(true), Literal(3)) :: + (NonFoldableLiteral(false), Literal(4)) :: + Nil, + None), + CaseWhen((Or(GreaterThan(Rand(0), Literal(0.5)), NonFoldableLiteral(true)), Literal(1)) :: + (Or(LessThan(Rand(1), Literal(0.5)), NonFoldableLiteral(true)), Literal(3)) :: + (NonFoldableLiteral(false), Literal(4)) :: + Nil, + None) + ) + + // The first two conditions can be combined, and then the optimizer uses rule in `Or` + // to be optimized into `TrueLiteral`. Thus, the entire `CaseWhen` can be removed. + assertEquivalent( + CaseWhen((UnresolvedAttribute("a"), Literal(1)) :: + (Not(UnresolvedAttribute("a")), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(3)) :: + (NonFoldableLiteral(true), Literal(4)) :: + (NonFoldableLiteral(false), Literal(5)) :: + Nil, + None), + Literal(1)) + } } From 42449916dd9a13da7e024987588e4c01425f2332 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 30 Jul 2018 10:50:50 -0700 Subject: [PATCH 2/3] temp --- .../sql/catalyst/optimizer/expressions.scala | 44 ++++++++++++------- .../optimizer/SimplifyConditionalSuite.scala | 2 +- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 003f5e65d5e75..07c85aea74ebe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -417,22 +417,34 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) - case e @ CaseWhen(branches, _) => - val newBranches = branches.foldLeft(List[(Expression, Expression)]()) { - case (newBranches, branch) => - if (newBranches.exists(_._1.semanticEquals(branch._1))) { - // If a condition in a branch is previously seen, this branch can be pruned. - // TODO: In fact, if a condition is a sub-condition of the previous one, - // TODO: it can be pruned. This is less strict and can be implemented - // TODO: by decomposing seen conditions. - newBranches - } else if (newBranches.nonEmpty && newBranches.last._2.semanticEquals(branch._2)) { - // If the outputs of two adjacent branches are the same, two branches can be combined. - newBranches.take(newBranches.length - 1) - .:+((Or(newBranches.last._1, branch._1), newBranches.last._2)) - } else { - newBranches.:+(branch) - } + case e @ CaseWhen(branches, _) if { + + true + } => + val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) { + case (newBranches, branch) if newBranches.exists(_._1.semanticEquals(branch._1)) => + // If a condition in a branch is previously seen, this branch can be pruned. + // TODO: In fact, if a condition is a sub-condition of the previous one, + // TODO: it can be pruned. This is less strict and can be implemented + // TODO: by decomposing the seen conditions. + newBranches + case (newBranches, branch) => newBranches += branch + } + if (newBranches.length < branches.length) { + e.copy(branches = newBranches) + } else { + e + } + + case e @ CaseWhen(branches, _) if { + true + } => + val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) { + case (newBranches, branch) + if newBranches.nonEmpty && newBranches.last._2.semanticEquals(branch._2) => + // If the outputs of two adjacent branches are the same, two branches can be combined. + newBranches.init += ((Or(newBranches.last._1, branch._1), newBranches.last._2)) + case (newBranches, branch) => newBranches += branch } if (newBranches.length < branches.length) { e.copy(branches = newBranches) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index a1c69012d3871..afdbee9a59f32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -156,7 +156,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Nil, None), CaseWhen((Or(GreaterThan(Rand(0), Literal(0.5)), NonFoldableLiteral(true)), Literal(1)) :: - (Or(LessThan(Rand(1), Literal(0.5)), NonFoldableLiteral(true)), Literal(3)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(3)) :: (NonFoldableLiteral(false), Literal(4)) :: Nil, None) From 691564780af6745e6aa60594a4829ebf28535202 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 30 Jul 2018 13:36:27 -0700 Subject: [PATCH 3/3] refactoring --- .../sql/catalyst/optimizer/expressions.scala | 71 ++++++++++--------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 07c85aea74ebe..cf25c3bcfa8ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -385,6 +385,40 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case _ => false } + // If a condition in a branch is previously seen, this branch can be pruned. + // TODO: In fact, if a condition is a sub-condition of the previous one, + // TODO: it can be pruned. This is less strict and can be implemented + // TODO: by decomposing the seen conditions. + private def pruneSeenBranches(branches: Seq[(Expression, Expression)]) + : Option[Seq[(Expression, Expression)]] = { + val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) { + case (newBranches, branch) if newBranches.exists(_._1.semanticEquals(branch._1)) => + newBranches + case (newBranches, branch) => newBranches += branch + } + if (newBranches.length < branches.length) { + Some(newBranches) + } else { + None + } + } + + // If the outputs of two adjacent branches are the same, two branches can be combined. + private def combineAdjacentBranches(branches: Seq[(Expression, Expression)]) + : Option[Seq[(Expression, Expression)]] = { + val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) { + case (newBranches, branch) + if newBranches.nonEmpty && newBranches.last._2.semanticEquals(branch._2) => + newBranches.init += ((Or(newBranches.last._1, branch._1), newBranches.last._2)) + case (newBranches, branch) => newBranches += branch + } + if (newBranches.length < branches.length) { + Some(newBranches) + } else { + None + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue @@ -417,40 +451,11 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) - case e @ CaseWhen(branches, _) if { - - true - } => - val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) { - case (newBranches, branch) if newBranches.exists(_._1.semanticEquals(branch._1)) => - // If a condition in a branch is previously seen, this branch can be pruned. - // TODO: In fact, if a condition is a sub-condition of the previous one, - // TODO: it can be pruned. This is less strict and can be implemented - // TODO: by decomposing the seen conditions. - newBranches - case (newBranches, branch) => newBranches += branch - } - if (newBranches.length < branches.length) { - e.copy(branches = newBranches) - } else { - e - } + case e @ CaseWhen(branches, _) if pruneSeenBranches(branches).nonEmpty => + e.copy(branches = pruneSeenBranches(branches).get) - case e @ CaseWhen(branches, _) if { - true - } => - val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) { - case (newBranches, branch) - if newBranches.nonEmpty && newBranches.last._2.semanticEquals(branch._2) => - // If the outputs of two adjacent branches are the same, two branches can be combined. - newBranches.init += ((Or(newBranches.last._1, branch._1), newBranches.last._2)) - case (newBranches, branch) => newBranches += branch - } - if (newBranches.length < branches.length) { - e.copy(branches = newBranches) - } else { - e - } + case e @ CaseWhen(branches, _) if combineAdjacentBranches(branches).nonEmpty => + e.copy(branches = combineAdjacentBranches(branches).get) } } }