From f9b32d5d044a899529959ad5042f8cf95c789ea8 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 28 Aug 2018 14:18:05 +0800 Subject: [PATCH] left/right join support push down during-join predicates --- .../sql/catalyst/optimizer/Optimizer.scala | 13 +++++++--- .../InferFiltersFromConstraintsSuite.scala | 25 +++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 63a62cd0cbfe6..1bb5224631f29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1122,7 +1122,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details */ -object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { +object PushPredicateThroughJoin extends Rule[LogicalPlan] + with PredicateHelper with ConstraintHelper { /** * Splits join condition expressions or filter predicates (on a given join's output) into three * categories based on the attributes required to evaluate them. Note that we explicitly exclude @@ -1190,11 +1191,13 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { // push down the join filter into sub query scanning if applicable case j @ Join(left, right, joinType, joinCondition) => - val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = - split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) + val condition = joinCondition.map(splitConjunctivePredicates).getOrElse(Nil) + val additionalCondition = inferAdditionalConstraints(condition.toSet) joinType match { case _: InnerLike | LeftSemi => + val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = + split(condition, left, right) // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -1204,6 +1207,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { Join(newLeft, newRight, joinType, newJoinCond) case RightOuter => + val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = + split(condition ++ additionalCondition, left, right) // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -1212,6 +1217,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { Join(newLeft, newRight, RightOuter, newJoinCond) case LeftOuter | LeftAnti | ExistenceJoin(_) => + val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = + split(condition ++ additionalCondition, left, right) // push down the right side only join filter for right sub query val newLeft = left val newRight = rightJoinConditions. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e4671f0d1cce6..add674438d7a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -263,4 +263,29 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val y = testRelation.subquery('y) testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } + + test("SPARK-25259: left/right join support push down during-join predicates") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val queryAnswers = Seq( + ( + x.join(y, joinType = LeftOuter, + condition = Some("x.b".attr === "y.b".attr && "x.b".attr === 1)), + x.join(y.where("b".attr.isNotNull && "b".attr === 1), joinType = LeftOuter, + condition = Some("x.b".attr === "y.b".attr && "x.b".attr === 1)) + ), + ( + x.join(y, joinType = RightOuter, + condition = Some("x.b".attr === "y.b".attr && "y.b".attr === 1)), + x.where("b".attr.isNotNull && "b".attr === 1).join(y, joinType = RightOuter, + condition = Some("x.b".attr === "y.b".attr && "y.b".attr === 1)) + ) + ) + + queryAnswers foreach { queryAnswerPair => + val optimized = Optimize.execute(queryAnswerPair._1.analyze) + comparePlans(optimized, queryAnswerPair._2.analyze) + } + } }