From 68c89866962b836f479a3fc41fbd503f8bc7ff47 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 17 Jan 2015 16:10:20 -0800 Subject: [PATCH] [SQL][Minor] Added comments and examples to explain BooleanSimplification. --- .../sql/catalyst/optimizer/Optimizer.scala | 177 ++++++++++-------- 1 file changed, 94 insertions(+), 83 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 522f14b0917e8..81bb012ac6d74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -302,89 +302,100 @@ object OptimizeIn extends Rule[LogicalPlan] { object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case and @ And(left, right) => - (left, right) match { - case (Literal(true, BooleanType), r) => r - case (l, Literal(true, BooleanType)) => l - case (Literal(false, BooleanType), _) => Literal(false) - case (_, Literal(false, BooleanType)) => Literal(false) - // a && a => a - case (l, r) if l fastEquals r => l - case (_, _) => - /* Do optimize for predicates using formula (a || b) && (a || c) => a || (b && c) - * 1. Split left and right to get the disjunctive predicates, - * i.e. lhsSet = (a, b), rhsSet = (a, c) - * 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - * 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - * 4. Apply the formula, get the optimized predict: common || (ldiff && rdiff) - */ - val lhsSet = splitDisjunctivePredicates(left).toSet - val rhsSet = splitDisjunctivePredicates(right).toSet - val common = lhsSet.intersect(rhsSet) - val ldiff = lhsSet.diff(common) - val rdiff = rhsSet.diff(common) - if (ldiff.size == 0 || rdiff.size == 0) { - // a && (a || b) => a - common.reduce(Or) - } else { - // (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... => - // (a || b) || ((c || ...) && (f || ...) && (e || ...) && ...) - (ldiff.reduceOption(Or) ++ rdiff.reduceOption(Or)) - .reduceOption(And) - .map(_ :: common.toList) - .getOrElse(common.toList) - .reduce(Or) - } - } - - case or @ Or(left, right) => - (left, right) match { - case (Literal(true, BooleanType), _) => Literal(true) - case (_, Literal(true, BooleanType)) => Literal(true) - case (Literal(false, BooleanType), r) => r - case (l, Literal(false, BooleanType)) => l - // a || a => a - case (l, r) if l fastEquals r => l - case (_, _) => - /* Do optimize for predicates using formula (a && b) || (a && c) => a && (b || c) - * 1. Split left and right to get the conjunctive predicates, - * i.e. lhsSet = (a, b), rhsSet = (a, c) - * 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - * 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - * 4. Apply the formula, get the optimized predict: common && (ldiff || rdiff) - */ - val lhsSet = splitConjunctivePredicates(left).toSet - val rhsSet = splitConjunctivePredicates(right).toSet - val common = lhsSet.intersect(rhsSet) - val ldiff = lhsSet.diff(common) - val rdiff = rhsSet.diff(common) - if ( ldiff.size == 0 || rdiff.size == 0) { - // a || (b && a) => a - common.reduce(And) - } else { - // (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... => - // a && b && ((c && ...) || (d && ...) || (e && ...) || ...) - (ldiff.reduceOption(And) ++ rdiff.reduceOption(And)) - .reduceOption(Or) - .map(_ :: common.toList) - .getOrElse(common.toList) - .reduce(And) - } - } - - case not @ Not(exp) => - exp match { - case Literal(true, BooleanType) => Literal(false) - case Literal(false, BooleanType) => Literal(true) - case GreaterThan(l, r) => LessThanOrEqual(l, r) - case GreaterThanOrEqual(l, r) => LessThan(l, r) - case LessThan(l, r) => GreaterThanOrEqual(l, r) - case LessThanOrEqual(l, r) => GreaterThan(l, r) - case Not(e) => e - case _ => not - } - - // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". + case and @ And(left, right) => (left, right) match { + // true && r => r + case (Literal(true, BooleanType), r) => r + // l && true => l + case (l, Literal(true, BooleanType)) => l + // false && r => false + case (Literal(false, BooleanType), _) => Literal(false) + // l && false => false + case (_, Literal(false, BooleanType)) => Literal(false) + // a && a => a + case (l, r) if l fastEquals r => l + // (a || b) && (a || c) => a || (b && c) + case (_, _) => + // 1. Split left and right to get the disjunctive predicates, + // i.e. lhsSet = (a, b), rhsSet = (a, c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) + val lhsSet = splitDisjunctivePredicates(left).toSet + val rhsSet = splitDisjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + val ldiff = lhsSet.diff(common) + val rdiff = rhsSet.diff(common) + if (ldiff.size == 0 || rdiff.size == 0) { + // a && (a || b) => a + common.reduce(Or) + } else { + // (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... => + // (a || b) || ((c || ...) && (f || ...) && (e || ...) && ...) + (ldiff.reduceOption(Or) ++ rdiff.reduceOption(Or)) + .reduceOption(And) + .map(_ :: common.toList) + .getOrElse(common.toList) + .reduce(Or) + } + } // end of And(left, right) + + case or @ Or(left, right) => (left, right) match { + // true || r => true + case (Literal(true, BooleanType), _) => Literal(true) + // r || true => true + case (_, Literal(true, BooleanType)) => Literal(true) + // false || r => r + case (Literal(false, BooleanType), r) => r + // l || false => l + case (l, Literal(false, BooleanType)) => l + // a || a => a + case (l, r) if l fastEquals r => l + // (a && b) || (a && c) => a && (b || c) + case (_, _) => + // 1. Split left and right to get the conjunctive predicates, + // i.e. lhsSet = (a, b), rhsSet = (a, c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) + val lhsSet = splitConjunctivePredicates(left).toSet + val rhsSet = splitConjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + val ldiff = lhsSet.diff(common) + val rdiff = rhsSet.diff(common) + if ( ldiff.size == 0 || rdiff.size == 0) { + // a || (b && a) => a + common.reduce(And) + } else { + // (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... => + // a && b && ((c && ...) || (d && ...) || (e && ...) || ...) + (ldiff.reduceOption(And) ++ rdiff.reduceOption(And)) + .reduceOption(Or) + .map(_ :: common.toList) + .getOrElse(common.toList) + .reduce(And) + } + } // end of Or(left, right) + + case not @ Not(exp) => exp match { + // not(true) => false + case Literal(true, BooleanType) => Literal(false) + // not(false) => true + case Literal(false, BooleanType) => Literal(true) + // not(l > r) => l <= r + case GreaterThan(l, r) => LessThanOrEqual(l, r) + // not(l >= r) => l < r + case GreaterThanOrEqual(l, r) => LessThan(l, r) + // not(l < r) => l >= r + case LessThan(l, r) => GreaterThanOrEqual(l, r) + // not(l <= r) => l > r + case LessThanOrEqual(l, r) => GreaterThan(l, r) + // not(not(e)) => e + case Not(e) => e + case _ => not + } // end of Not(exp) + + // if (true) a else b => a + // if (false) a else b => b case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } }