Skip to content

Commit

Permalink
Fix join pushdown.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Feb 8, 2016
1 parent 060b9b8 commit 00e7f39
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,19 @@ object EliminateSerialization extends Rule[LogicalPlan] {
*/
object LimitPushDown extends Rule[LogicalPlan] {

/**
 * Returns the child of `plan` when `plan` is a [[GlobalLimit]]; any other plan is returned
 * unchanged.
 */
private def stipGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = {
private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = {
plan match {
// Only the outermost node is inspected; the limit expression of the removed node is discarded.
case GlobalLimit(expr, child) => child
case _ => plan
}
}

/**
 * Decides whether a [[LocalLimit]] of `limitExp` should be placed on top of `plan`.
 *
 * The limit is added (after removing an outermost [[GlobalLimit]] from `plan`) when
 * `plan.maxRows` is `None`, or when both `limitExp` and `plan.maxRows` are integer literals and
 * the limit is strictly smaller. In every other case `plan` is returned unchanged.
 */
private def buildUnionChild(limitExp: Expression, plan: LogicalPlan): LogicalPlan = {
private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = {
(limitExp, plan.maxRows) match {
// The new limit is strictly smaller than the row bound the child already reports.
case (IntegerLiteral(maxRow), Some(IntegerLiteral(childMaxRows))) if maxRow < childMaxRows =>
LocalLimit(limitExp, stipGlobalLimitIfPresent(plan))
LocalLimit(limitExp, stripGlobalLimitIfPresent(plan))
// The child reports no row bound at all, so the limit is always added.
case (_, None) =>
LocalLimit(limitExp, stipGlobalLimitIfPresent(plan))
LocalLimit(limitExp, stripGlobalLimitIfPresent(plan))
// Anything else (e.g. the child's bound is not larger than the limit) leaves the plan as is.
case _ => plan
}
}
Expand All @@ -160,7 +160,15 @@ object LimitPushDown extends Rule[LogicalPlan] {
// pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to
// pushdown Limit.
case LocalLimit(exp, Union(children)) =>
LocalLimit(exp, Union(children.map(buildUnionChild(exp, _))))
LocalLimit(exp, Union(children.map(maybePushLimit(exp, _))))
// Outer joins: the limit is offered to the right input of a RightOuter join, to the left input
// of a LeftOuter join, and to both inputs of a FullOuter join. Other join types are untouched.
// NOTE(review): every branch below returns the Join itself, so the LocalLimit matched by this
// case is not re-applied on top of the join (the Union case above keeps it) - confirm that
// dropping it is intended.
case LocalLimit(exp, join @ Join(left, right, joinType, condition)) =>
joinType match {
case RightOuter => join.copy(right = maybePushLimit(exp, right))
case LeftOuter => join.copy(left = maybePushLimit(exp, left))
// NOTE(review): limiting both inputs of a full outer join looks unsafe - a row whose match was
// cut from the other side would surface as a null-extended row that the unlimited join never
// produces. Verify before relying on this branch.
case FullOuter =>
join.copy(left = maybePushLimit(exp, left), right = maybePushLimit(exp, right))
case _ => join
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
Expand All @@ -39,6 +39,8 @@ class LimitPushdownSuite extends PlanTest {

// Two local relations with disjoint column names.
private val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
private val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
// Join inputs for the outer-join tests. Note that both aliases wrap the same `testRelation`.
private val x = testRelation.subquery('x)
private val y = testRelation.subquery('y)

test("Union: limit to each side") {
val unionQuery = Union(testRelation, testRelation2).limit(1)
Expand Down Expand Up @@ -72,75 +74,26 @@ class LimitPushdownSuite extends PlanTest {
Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d)))).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}
//
// test("limit: push down left outer join") {
// val x = testRelation.subquery('x)
// val y = testRelation.subquery('y)
//
// val originalQuery = {
// x.join(y, LeftOuter)
// .limit(1)
// }
//
// val optimized = Optimize.execute(originalQuery.analyze)
// val left = testRelation.limit(1)
// val correctAnswer =
// left.join(y, LeftOuter).limit(1).analyze
//
// comparePlans(optimized, correctAnswer)
// }
//
// test("limit: push down right outer join") {
// val x = testRelation.subquery('x)
// val y = testRelation.subquery('y)
//
// val originalQuery = {
// x.join(y, RightOuter)
// .limit(1)
// }
//
// val optimized = Optimize.execute(originalQuery.analyze)
// val right = testRelation.limit(1)
// val correctAnswer =
// x.join(right, RightOuter).limit(1).analyze
//
// comparePlans(optimized, correctAnswer)
// }
//
// test("limit: push down full outer join") {
// val x = testRelation.subquery('x)
// val y = testRelation.subquery('y)
//
// val originalQuery = {
// x.join(y, FullOuter)
// .limit(1)
// }
//
// val optimized = Optimize.execute(originalQuery.analyze)
// val left = testRelation
// val right = testRelation.limit(1)
// val correctAnswer =
// left.join(right, FullOuter).limit(1).analyze
//
// comparePlans(optimized, correctAnswer)
// }
//
// test("limit: push down full outer join + project") {
// val x = testRelation.subquery('x)
// val y = testRelation1.subquery('y)
//
// val originalQuery = {
// x.join(y, FullOuter).select('a, 'b, 'd)
// .limit(1)
// }
//
// val optimized = Optimize.execute(originalQuery.analyze)
// val left = testRelation.select('a, 'b)
// val right = testRelation1.limit(1)
// val correctAnswer =
// left.join(right, FullOuter).select('a, 'b, 'd).limit(1).analyze
//
// comparePlans(optimized, correctAnswer)
// }

test("push down left outer join") {
  // The query joins `x` (left) with `y` (right), and for a left outer join the rule offers the
  // limit to the left input. The expected plan therefore carries the LocalLimit on `x`; the
  // previous expectation used `y` there, which only matched because both aliases wrap the
  // same relation.
  val originalQuery = x.join(y, LeftOuter).limit(1)
  val optimized = Optimize.execute(originalQuery.analyze)
  val correctAnswer = GlobalLimit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze
  comparePlans(optimized, correctAnswer)
}

test("push down right outer join") {
  // A limit over a right outer join is expected to end up as a LocalLimit on the right input
  // (`y`), with the left input left untouched.
  val expectedPlan = GlobalLimit(1, x.join(LocalLimit(1, y), RightOuter)).analyze
  val limitedJoin = x.join(y, RightOuter).limit(1)
  val actualPlan = Optimize.execute(limitedJoin.analyze)
  comparePlans(actualPlan, expectedPlan)
}

test("push down full outer join") {
  // A limit over a full outer join is expected to end up as a LocalLimit on both inputs.
  // NOTE(review): this pins the rule's current behaviour; whether limiting both sides of a full
  // outer join preserves the query result should be confirmed separately.
  val expectedPlan =
    GlobalLimit(1, LocalLimit(1, x).join(LocalLimit(1, y), FullOuter)).analyze
  val limitedJoin = x.join(y, FullOuter).limit(1)
  val actualPlan = Optimize.execute(limitedJoin.analyze)
  comparePlans(actualPlan, expectedPlan)
}
}

0 comments on commit 00e7f39

Please sign in to comment.