Skip to content

Commit

Permalink
Improves compareConditions to handle more subtle cases
Browse files Browse the repository at this point in the history
  • Loading branch information
liancheng committed Jan 18, 2015
1 parent 1bf3258 commit cd8860b
Showing 1 changed file with 16 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.expressions.{Or, And, Literal, Expression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class BooleanSimplificationSuite extends PlanTest {
class BooleanSimplificationSuite extends PlanTest with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Expand All @@ -40,14 +40,21 @@ class BooleanSimplificationSuite extends PlanTest {

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)

// The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c`
def compareConditions(e1: Expression, e2: Expression): Boolean = (e1, e2) match {
case (And(l1, l2), And(r1, r2)) =>
compareConditions(l1, r1) && compareConditions(l2, r2) ||
compareConditions(l1, r2) && compareConditions(l2, r1)

case (Or(l1, l2), Or(r1, r2)) =>
compareConditions(l1, r1) && compareConditions(l2, r2) ||
compareConditions(l1, r2) && compareConditions(l2, r1)
case (lhs: And, rhs: And) =>
val lhsSet = splitConjunctivePredicates(lhs).toSet
val rhsSet = splitConjunctivePredicates(rhs).toSet
lhsSet.foldLeft(rhsSet) { (set, e) =>
set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
}.isEmpty

case (lhs: Or, rhs: Or) =>
val lhsSet = splitDisjunctivePredicates(lhs).toSet
val rhsSet = splitDisjunctivePredicates(rhs).toSet
lhsSet.foldLeft(rhsSet) { (set, e) =>
set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
}.isEmpty

case (l, r) => l == r
}
Expand Down

0 comments on commit cd8860b

Please sign in to comment.