From 90dd5f15a554560f25e812da3796ef711fd630f6 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sat, 2 Jan 2016 00:15:00 -0800
Subject: [PATCH] join push through unionall

---
 .../sql/catalyst/optimizer/Optimizer.scala    | 43 +++++++++++++++++++++++++-
 .../apache/spark/sql/DataFrameJoinSuite.scala | 38 ++++++++++++++++++++++
 2 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0b1c74293bb8b..6b8ccd11ef15d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -97,7 +97,10 @@ object SamplePushDown extends Rule[LogicalPlan] {
  * Operations that are safe to pushdown are listed as follows.
  * Union:
  * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
  * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
  * we will not be able to pushdown Projections.
+ * Joins are pushed through a Union only for the join types that distribute over it:
+ * Inner, LeftOuter and LeftSemi when the Union is the left child of the join, and
+ * Inner and RightOuter when the Union is the right child of the join.
  *
  * Intersect:
@@ -129,7 +132,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
    */
   private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
     val result = e transform {
-      case a: Attribute => rewrites(a)
+      case a: Attribute if rewrites.contains(a) => rewrites(a)
     }
 
     // We must promise the compiler that we did not discard the names in the case of project
@@ -164,6 +167,44 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         )
       )
 
+    // Push down a deterministic join condition through UNION ALL on the left side of a join.
+    // Only Inner, LeftOuter and LeftSemi joins distribute over a left-side Union, because
+    // every output row is driven by exactly one left row. RightOuter and FullOuter joins
+    // would emit each unmatched right row once per Union branch, so they are not rewritten.
+    case j @ Join(u @ Union(uLeft, uRight), jRight, joinType, condition) =>
+      import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemi}
+      val distributesOverUnion = joinType match {
+        case Inner | LeftOuter | LeftSemi => true
+        case _ => false
+      }
+      if (distributesOverUnion && condition.forall(_.deterministic)) {
+        val rewrites = buildRewrites(u)
+        Union(
+          Join(uLeft, jRight, joinType, condition),
+          Join(uRight, jRight, joinType, condition.map(pushToRight(_, rewrites))))
+      } else {
+        j
+      }
+
+    // Push down a deterministic join condition through UNION ALL on the right side of a join.
+    // Only Inner and RightOuter joins distribute over a right-side Union. LeftOuter and
+    // FullOuter joins would duplicate unmatched left rows, and a LeftSemi join would
+    // duplicate left rows that find a match in both Union branches.
+    case j @ Join(jLeft, u @ Union(uLeft, uRight), joinType, condition) =>
+      import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter}
+      val distributesOverUnion = joinType match {
+        case Inner | RightOuter => true
+        case _ => false
+      }
+      if (distributesOverUnion && condition.forall(_.deterministic)) {
+        val rewrites = buildRewrites(u)
+        Union(
+          Join(jLeft, uLeft, joinType, condition),
+          Join(jLeft, uRight, joinType, condition.map(pushToRight(_, rewrites))))
+      } else {
+        j
+      }
+
     // Push down deterministic projection through UNION ALL
     case p @ Project(projectList, u @ Union(left, right)) =>
       if (projectList.forall(_.deterministic)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 39a65413bd592..142f31ed5eb12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -140,4 +140,42 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       assert(df1.join(broadcast(pf1)).count() === 4)
     }
   }
+
+  test("join - join unionALL") {
+    val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
+    val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int_df2", "int2_df2", "str_df2")
+    val df3 = Seq((1, "1"), (5, "5")).toDF("int", "str")
+
+    checkAnswer(
+      df3.join(df.unionAll(df2), Seq("int", "str"), "inner"),
+      Row(1, "1", 2) ::
+        Row(1, "1", 3) ::
+        Row(5, "5", 6) :: Nil)
+
+    checkAnswer(
+      df.unionAll(df2).join(df3, Seq("int", "str"), "inner"),
+      Row(1, "1", 2) ::
+        Row(1, "1", 3) ::
+        Row(5, "5", 6) :: Nil)
+  }
+
+  test("join - outer join over unionALL does not duplicate unmatched rows") {
+    val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
+    val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int_df2", "int2_df2", "str_df2")
+    val df3 = Seq((1, "1"), (7, "7")).toDF("int", "str")
+    val unioned = df.unionAll(df2)
+
+    // The (7, "7") row matches nothing in either Union branch and must appear exactly once.
+    checkAnswer(
+      unioned.join(df3, unioned("int") === df3("int"), "right_outer"),
+      Row(1, 2, "1", 1, "1") ::
+        Row(1, 3, "1", 1, "1") ::
+        Row(null, null, null, 7, "7") :: Nil)
+
+    checkAnswer(
+      df3.join(unioned, df3("int") === unioned("int"), "left_outer"),
+      Row(1, "1", 1, 2, "1") ::
+        Row(1, "1", 1, 3, "1") ::
+        Row(7, "7", null, null, null) :: Nil)
+  }
 }