Skip to content

Commit

Permalink
[SPARK-27815][SQL] Predicate pushdown in one pass for cascading joins
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR makes the predicate pushdown logic in the Catalyst optimizer more efficient by unifying two existing rules, `PushDownPredicate` and `PushPredicateThroughJoin`. Previously, pushing down a predicate for queries such as `Filter(Join(Join(Join)))` required n steps. This patch essentially reduces this to a single pass.

To make this actually work, we need to unify a few rules such as `CombineFilters`, `PushDownPredicate` and `PushPredicateThroughJoin`. Otherwise, cases such as `Filter(Join(Filter(Join)))` still require several passes to fully push down predicates. This unification is done by composing several partial functions, which keeps the code change minimal and allows existing UTs to be reused.

Results show that this optimization can improve the catalyst optimization time by 16.5%. For queries with more joins, the performance is even better. E.g., for TPC-DS q64, the performance boost is 49.2%.

## How was this patch tested?
Existing UTs + a new UT for the new rule.

Closes #24956 from yeshengm/fixed-point-opt.

Authored-by: Yesheng Ma <kimi.ysma@gmail.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
  • Loading branch information
yeshengm authored and gatorsmile committed Jul 3, 2019
1 parent 70b1a10 commit 74f1176
Show file tree
Hide file tree
Showing 17 changed files with 235 additions and 34 deletions.
Expand Up @@ -63,8 +63,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
PushProjectionThroughUnion,
ReorderJoin,
EliminateOuterJoin,
PushPredicateThroughJoin,
PushDownPredicate,
PushDownPredicates,
PushDownLeftSemiAntiJoin,
PushLeftSemiLeftAntiThroughJoin,
LimitPushDown,
Expand Down Expand Up @@ -911,7 +910,9 @@ object CombineUnions extends Rule[LogicalPlan] {
* one conjunctive predicate.
*/
object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally

val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
// The query execution/optimization does not guarantee the expressions are evaluated in order.
// We only can combine them if and only if both are deterministic.
case Filter(fc, nf @ Filter(nc, grandChild)) if fc.deterministic && nc.deterministic =>
Expand Down Expand Up @@ -996,15 +997,30 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Unified predicate pushdown covering both ordinary (non-join) operators and joins.
 *
 * At every plan node the rule tries, in this order, [[CombineFilters]],
 * [[PushPredicateThroughNonJoin]] and [[PushPredicateThroughJoin]], and applies the first one
 * whose `applyLocally` is defined for that node. Chaining them inside one traversal improves
 * the performance of predicate pushdown for cascading joins such as Filter-Join-Join-Join:
 * most predicates can be pushed down in a single pass.
 */
object PushDownPredicates extends Rule[LogicalPlan] with PredicateHelper {
  def apply(plan: LogicalPlan): LogicalPlan = {
    val mergeFilters = CombineFilters.applyLocally
    val throughNonJoin = PushPredicateThroughNonJoin.applyLocally
    val throughJoin = PushPredicateThroughJoin.applyLocally
    // `orElse` only consults a later function when the earlier ones are undefined for the
    // node at hand, so the order of the chain is significant.
    plan.transform(mergeFilters.orElse(throughNonJoin).orElse(throughJoin))
  }
}

/**
* Pushes [[Filter]] operators through many operators iff:
* 1) the operator is deterministic
* 2) the predicate is deterministic and the operator will not change any of rows.
*
* This heuristic is valid assuming the expression evaluation cost is minimal.
*/
object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally

val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
// SPARK-13473: We can't push the predicate down when the underlying projection output non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
// implies that, for a given input row, the output are determined by the expression's initial
Expand Down Expand Up @@ -1221,7 +1237,9 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally

val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
// push the where condition down into join filter
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, hint)) =>
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
Expand Down
Expand Up @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule

/**
* This rule is a variant of [[PushDownPredicate]] which can handle
* This rule is a variant of [[PushPredicateThroughNonJoin]] which can handle
* pushing down Left semi and Left Anti joins below the following operators.
* 1) Project
* 2) Window
* 3) Union
* 4) Aggregate
* 5) Other permissible unary operators. please see [[PushDownPredicate.canPushThrough]].
* 5) Other permissible unary operators. please see [[PushPredicateThroughNonJoin.canPushThrough]].
*/
object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
Expand All @@ -42,7 +42,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
// No join condition, just push down the Join below Project
p.copy(child = Join(gChild, rightOp, joinType, joinCond, hint))
} else {
val aliasMap = PushDownPredicate.getAliasMap(p)
val aliasMap = PushPredicateThroughNonJoin.getAliasMap(p)
val newJoinCond = if (aliasMap.nonEmpty) {
Option(replaceAlias(joinCond.get, aliasMap))
} else {
Expand All @@ -55,7 +55,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
!agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
val aliasMap = PushDownPredicate.getAliasMap(agg)
val aliasMap = PushPredicateThroughNonJoin.getAliasMap(agg)
val canPushDownPredicate = (predicate: Expression) => {
val replaced = replaceAlias(predicate, aliasMap)
predicate.references.nonEmpty &&
Expand Down Expand Up @@ -94,7 +94,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {

// LeftSemi/LeftAnti over UnaryNode
case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(_), _, _)
if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
if PushPredicateThroughNonJoin.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
val validAttrs = u.child.outputSet ++ rightOp.outputSet
pushDownJoin(join, _.references.subsetOf(validAttrs), _.reduce(And))
}
Expand Down
Expand Up @@ -32,7 +32,7 @@ class ColumnPruningSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
PushDownPredicate,
PushPredicateThroughNonJoin,
ColumnPruning,
RemoveNoopOperators,
CollapseProject) :: Nil
Expand Down
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

/**
 * Checks that [[PushDownPredicates]] performs predicate pushdown efficiently. Every test runs
 * the pushdown batch exactly once and expects all predicates to already sit as far down the
 * plan as they can be pushed.
 */
class FilterPushdownOnePassSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Subqueries", Once,
        EliminateSubqueryAliases),
      // A single execution of this batch has to produce the expected plan.
      Batch("Filter Pushdown One Pass", Once,
        ReorderJoin,
        PushDownPredicates))
  }

  val testRelation1 = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'd.int, 'e.int)

  /** Analyzes `query` and runs it through the one-pass optimizer defined above. */
  private def optimize(query: LogicalPlan): LogicalPlan = Optimize.execute(query.analyze)

  test("really simple predicate push down") {
    val relX = testRelation1.subquery('x)
    val relY = testRelation2.subquery('y)

    val input = relX.join(relY).where("x.a".attr === 1)

    val actual = optimize(input)
    val expected = relX.where("x.a".attr === 1).join(relY).analyze

    comparePlans(actual, expected)
  }

  test("push down conjunctive predicates") {
    val relX = testRelation1.subquery('x)
    val relY = testRelation2.subquery('y)

    val input = relX.join(relY).where("x.a".attr === 1 && "y.d".attr < 1)

    val actual = optimize(input)
    // Each conjunct ends up on the side of the join it refers to.
    val expected =
      relX.where("x.a".attr === 1)
        .join(relY.where("y.d".attr < 1))
        .analyze

    comparePlans(actual, expected)
  }

  test("push down predicates for simple joins") {
    val relX = testRelation1.subquery('x)
    val relY = testRelation2.subquery('y)

    val input =
      relX.where("x.c".attr < 0)
        .join(relY.where("y.d".attr > 1))
        .where("x.a".attr === 1 && "y.d".attr < 2)

    val actual = optimize(input)
    // Pushed-down predicates are merged with the filters already present below the join.
    val expected =
      relX.where("x.c".attr < 0 && "x.a".attr === 1)
        .join(relY.where("y.d".attr > 1 && "y.d".attr < 2))
        .analyze

    comparePlans(actual, expected)
  }

  test("push down top-level filters for cascading joins") {
    val relX = testRelation1.subquery('x)
    val relY = testRelation2.subquery('y)

    // Left-deep chain: `base` joined with x five times in a row.
    def joinXFiveTimes(base: LogicalPlan): LogicalPlan =
      Seq.fill(5)(relX).foldLeft(base)(_.join(_))

    val input = joinXFiveTimes(relY).where("y.d".attr === 0)

    val actual = optimize(input)
    val expected = joinXFiveTimes(relY.where("y.d".attr === 0)).analyze

    comparePlans(actual, expected)
  }

  test("push down predicates for tree-like joins") {
    val relX = testRelation1.subquery('x)
    val relY1 = testRelation2.subquery('y1)
    val relY2 = testRelation2.subquery('y2)

    val input =
      relY1.join(relX).join(relX)
        .join(relY2.join(relX).join(relX))
        .where("y1.d".attr === 0 && "y2.d".attr === 3)

    val actual = optimize(input)
    // Each predicate travels down its own branch of the join tree.
    val expected =
      relY1.where("y1.d".attr === 0).join(relX).join(relX)
        .join(relY2.where("y2.d".attr === 3).join(relX).join(relX))
        .analyze

    comparePlans(actual, expected)
  }

  test("push down through join and project") {
    val relX = testRelation1.subquery('x)
    val relY = testRelation2.subquery('y)

    val input =
      relX.where('a > 0).select('a, 'b)
        .join(relY.where('d < 100).select('e))
        .where("x.a".attr < 100)

    val actual = optimize(input)
    // The top-level filter crosses both the join and the projection on the left side.
    val expected =
      relX.where('a > 0 && 'a < 100).select('a, 'b)
        .join(relY.where('d < 100).select('e))
        .analyze

    comparePlans(actual, expected)
  }

  test("push down through deep projects") {
    val relX = testRelation1.subquery('x)

    val input =
      relX.select(('a + 1) as 'a1, 'b)
        .select(('a1 + 1) as 'a2, 'b)
        .select(('a2 + 1) as 'a3, 'b)
        .select(('a3 + 1) as 'a4, 'b)
        .select('b)
        .where('b > 0)

    val actual = optimize(input)
    // The filter on 'b is expected below the whole stack of projections.
    val expected =
      relX.where('b > 0)
        .select(('a + 1) as 'a1, 'b)
        .select(('a1 + 1) as 'a2, 'b)
        .select(('a2 + 1) as 'a3, 'b)
        .select(('a3 + 1) as 'a4, 'b)
        .select('b)
        .analyze

    comparePlans(actual, expected)
  }

  test("push down through aggregate and join") {
    val relX = testRelation1.subquery('x)
    val relY = testRelation2.subquery('y)

    val left = relX
      .where('c > 0)
      .groupBy('a)('a, count('b))
      .subquery('left)
    val right = relY
      .where('d < 0)
      .groupBy('a)('a, count('d))
      .subquery('right)
    val input = left
      .join(right)
      .where("left.a".attr < 100 && "right.a".attr < 100)

    val actual = optimize(input)
    // Predicates on the grouping column 'a are expected below each aggregate.
    val expected =
      relX.where('c > 0 && 'a < 100).groupBy('a)('a, count('b))
        .join(relY.where('d < 0 && 'a < 100).groupBy('a)('a, count('d)))
        .analyze

    comparePlans(actual, expected)
  }
}
Expand Up @@ -35,7 +35,7 @@ class FilterPushdownSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Filter Pushdown", FixedPoint(10),
CombineFilters,
PushDownPredicate,
PushPredicateThroughNonJoin,
BooleanSimplification,
PushPredicateThroughJoin,
CollapseProject) :: Nil
Expand Down
Expand Up @@ -31,7 +31,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val batches =
Batch("InferAndPushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushDownPredicate,
PushPredicateThroughNonJoin,
InferFiltersFromConstraints,
CombineFilters,
SimplifyBinaryComparison,
Expand Down
Expand Up @@ -34,7 +34,7 @@ class JoinOptimizationSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
PushDownPredicate,
PushPredicateThroughNonJoin,
BooleanSimplification,
ReorderJoin,
PushPredicateThroughJoin,
Expand Down
Expand Up @@ -35,7 +35,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
EliminateResolvedHint) ::
Batch("Operator Optimizations", FixedPoint(100),
CombineFilters,
PushDownPredicate,
PushPredicateThroughNonJoin,
ReorderJoin,
PushPredicateThroughJoin,
ColumnPruning,
Expand Down
Expand Up @@ -35,7 +35,7 @@ class LeftSemiPushdownSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Filter Pushdown", FixedPoint(10),
CombineFilters,
PushDownPredicate,
PushPredicateThroughNonJoin,
PushDownLeftSemiAntiJoin,
PushLeftSemiLeftAntiThroughJoin,
BooleanSimplification,
Expand Down
Expand Up @@ -34,7 +34,7 @@ class OptimizerLoggingSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Optimizer Batch", FixedPoint(100),
PushDownPredicate, ColumnPruning, CollapseProject) ::
PushPredicateThroughNonJoin, ColumnPruning, CollapseProject) ::
Batch("Batch Has No Effect", Once,
ColumnPruning) :: Nil
}
Expand Down Expand Up @@ -99,7 +99,7 @@ class OptimizerLoggingSuite extends PlanTest {
verifyLog(
level._2,
Seq(
PushDownPredicate.ruleName,
PushPredicateThroughNonJoin.ruleName,
ColumnPruning.ruleName,
CollapseProject.ruleName))
}
Expand All @@ -123,15 +123,15 @@ class OptimizerLoggingSuite extends PlanTest {

test("test log rules") {
val rulesSeq = Seq(
Seq(PushDownPredicate.ruleName,
Seq(PushPredicateThroughNonJoin.ruleName,
ColumnPruning.ruleName,
CollapseProject.ruleName).reduce(_ + "," + _) ->
Seq(PushDownPredicate.ruleName,
Seq(PushPredicateThroughNonJoin.ruleName,
ColumnPruning.ruleName,
CollapseProject.ruleName),
Seq(PushDownPredicate.ruleName,
Seq(PushPredicateThroughNonJoin.ruleName,
ColumnPruning.ruleName).reduce(_ + "," + _) ->
Seq(PushDownPredicate.ruleName,
Seq(PushPredicateThroughNonJoin.ruleName,
ColumnPruning.ruleName),
CollapseProject.ruleName ->
Seq(CollapseProject.ruleName),
Expand Down

0 comments on commit 74f1176

Please sign in to comment.