Skip to content

Commit

Permalink
[SPARK-20758][SQL] Add Constant propagation optimization
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

See class doc of `ConstantPropagation` for the approach used.

## How was this patch tested?

- Added unit tests

Author: Tejas Patil <tejasp@fb.com>

Closes #17993 from tejasapatil/SPARK-20758_const_propagation.
  • Loading branch information
tejasapatil authored and hvanhovell committed May 29, 2017
1 parent 9d0db5a commit f9b59ab
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
CombineUnions,
// Constant folding and strength reduction
NullPropagation(conf),
ConstantPropagation,
FoldablePropagation,
OptimizeIn(conf),
ConstantFolding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,62 @@ object ConstantFolding extends Rule[LogicalPlan] {
}
}

/**
 * Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
 * value in conjunctive [[Expression Expressions]]
 * e.g.
 * {{{
 *   SELECT * FROM table WHERE i = 5 AND j = i + 3
 *   ==>  SELECT * FROM table WHERE i = 5 AND j = 8
 * }}}
 *
 * Approach used:
 * - Walk the filter condition top-down, remembering whether a `null` result at the current
 *   position is indistinguishable from `false` (true at the root of the condition and through
 *   AND / OR, false below anything else such as NOT)
 * - For every AND node (children first), look at its direct conjuncts of the form
 *   `attribute = literal` / `attribute <=> literal` (in either operand order)
 * - Populate a mapping of attribute => constant value from those equality predicates
 * - Using this mapping, replace occurrences of the attributes with the corresponding constant
 *   values inside every other EqualTo / EqualNullSafe found below the AND node.
 *
 * Null safety: when the attribute is null, `attribute = literal` is null and the rewritten AND may
 * evaluate to `false` where the original evaluated to `null` (or vice versa). That is harmless
 * only where `null` and `false` are equivalent, so outside such positions an `=` predicate is
 * used only if the attribute is non-nullable and the literal is non-null. For example
 * `NOT(c = 1 AND c + 1 = 1)` must not become `NOT(c = 1 AND 1 + 1 = 1)` for a nullable `c`.
 * `<=>` never evaluates to null, so it can always be used.
 */
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case f: Filter =>
      // SQL WHERE semantics: a row is kept only when the condition is true, so at the root of the
      // condition a null result is equivalent to false.
      val newCondition = propagate(f.condition, nullIsFalse = true)
      if (newCondition fastEquals f.condition) f else f.copy(condition = newCondition)
  }

  /**
   * Rewrites every AND node found in `expression`, innermost first.
   *
   * @param expression  the (sub)expression to rewrite
   * @param nullIsFalse whether a null result of `expression` is equivalent to false for the
   *                    enclosing filter
   * @return the rewritten expression, or `expression` itself if nothing was replaced
   */
  private def propagate(expression: Expression, nullIsFalse: Boolean): Expression =
    expression match {
      case and: And =>
        // Children first, so that an enclosing AND sees the already rewritten inner ANDs.
        and.mapChildren(propagate(_, nullIsFalse)) match {
          case rewritten: And => replaceConstants(rewritten, nullIsFalse)
          case other => other
        }
      case or: Or =>
        // `x OR null` and `x OR false` only differ as null vs. false, so the flag is preserved.
        or.mapChildren(propagate(_, nullIsFalse))
      case other =>
        // Below any other expression (e.g. NOT) null and false are distinguishable.
        other.mapChildren(propagate(_, nullIsFalse = false))
    }

  /**
   * Whether `attr = literal` may be used as a source of constants. It is always fine where null
   * is equivalent to false; elsewhere the predicate must never evaluate to null.
   */
  private def safeToReplace(
      attr: AttributeReference,
      literal: Literal,
      nullIsFalse: Boolean): Boolean = {
    nullIsFalse || (!attr.nullable && literal.value != null)
  }

  /**
   * Replaces attributes that a direct conjunct of `and` binds to a literal with that literal in
   * all other EqualTo / EqualNullSafe expressions below `and`.
   */
  private def replaceConstants(and: And, nullIsFalse: Boolean): Expression = {
    // Only leaf-vs-leaf comparisons match, so no NOT / OR can hide inside a collected predicate.
    val equalityPredicates = splitConjunctivePredicates(and).collect {
      case e @ EqualTo(attr: AttributeReference, lit: Literal)
          if safeToReplace(attr, lit, nullIsFalse) => ((attr, lit), e)
      case e @ EqualTo(lit: Literal, attr: AttributeReference)
          if safeToReplace(attr, lit, nullIsFalse) => ((attr, lit), e)
      case e @ EqualNullSafe(attr: AttributeReference, lit: Literal) => ((attr, lit), e)
      case e @ EqualNullSafe(lit: Literal, attr: AttributeReference) => ((attr, lit), e)
    }

    if (equalityPredicates.isEmpty) {
      and
    } else {
      val constantsMap = AttributeMap(equalityPredicates.map(_._1))
      // The predicates that define the constants must stay untouched, otherwise `i = 5` would
      // turn into `5 = 5` and the binding would be lost.
      val predicates = equalityPredicates.map(_._2).toSet

      def substitute(expression: Expression) = expression transform {
        case a: AttributeReference => constantsMap.getOrElse(a, a)
      }

      and transform {
        case e @ EqualTo(_, _) if !predicates.contains(e) => substitute(e)
        case e @ EqualNullSafe(_, _) if !predicates.contains(e) => substitute(e)
      }
    }
  }
}

/**
* Reorder associative integral-type operators and fold all constants into one.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan-level tests for the [[ConstantPropagation]] rule: each test builds a filter over
 * `testRelation`, runs the small `Optimize` pipeline and compares the outcome with a hand-written
 * expected plan.
 */
class ConstantPropagationSuite extends PlanTest {

  /** Rule pipeline under test; the folding rules make the propagated constants visible. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("AnalysisNodes", Once,
        EliminateSubqueryAliases),
      Batch("ConstantPropagation", FixedPoint(10),
        ColumnPruning,
        ConstantPropagation,
        ConstantFolding,
        BooleanSimplification))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  private val columnA = 'a.int
  private val columnB = 'b.int
  private val columnC = 'c.int

  /** Builds `SELECT a FROM testRelation` with `condition` applied on top (not analyzed). */
  private def filterOnA(condition: Expression): LogicalPlan =
    testRelation.select(columnA).where(condition)

  /** Runs the optimizer on `input` and checks that the outcome equals `expected`. */
  private def assertOptimizedTo(input: LogicalPlan, expected: LogicalPlan): Unit =
    comparePlans(Optimize.execute(input), expected)

  test("basic test") {
    val input = filterOnA(columnA === Add(columnB, Literal(1)) && columnB === Literal(10))
    val expected = filterOnA(columnA === Literal(11) && columnB === Literal(10)).analyze

    assertOptimizedTo(input.analyze, expected)
  }

  test("with combination of AND and OR predicates") {
    val input = filterOnA(
      columnA === Add(columnB, Literal(1)) &&
        columnB === Literal(10) &&
        (columnA === Add(columnC, Literal(3)) || columnB === columnC)).analyze

    val expected = filterOnA(
      columnA === Literal(11) &&
        columnB === Literal(10) &&
        (Literal(11) === Add(columnC, Literal(3)) || Literal(10) === columnC)).analyze

    assertOptimizedTo(input, expected)
  }

  test("equality predicates outside a `NOT` can be propagated within a `NOT`") {
    val input =
      filterOnA(Not(columnA === Add(columnB, Literal(1))) && columnB === Literal(10)).analyze
    val expected =
      filterOnA(Not(columnA === Literal(11)) && columnB === Literal(10)).analyze

    assertOptimizedTo(input, expected)
  }

  test("equality predicates inside a `NOT` should not be picked for propagation") {
    val input =
      filterOnA(Not(columnB === Literal(10)) && columnA === Add(columnB, Literal(1))).analyze

    // Nothing to propagate: the plan must come back unchanged.
    assertOptimizedTo(input, input)
  }

  test("equality predicates outside a `OR` can be propagated within a `OR`") {
    val input = filterOnA(
      columnA === Literal(2) &&
        (columnA === Add(columnB, Literal(3)) || columnB === Literal(9))).analyze

    val expected = filterOnA(
      columnA === Literal(2) &&
        (Literal(2) === Add(columnB, Literal(3)) || columnB === Literal(9))).analyze

    assertOptimizedTo(input, expected)
  }

  test("equality predicates inside a `OR` should not be picked for propagation") {
    val input = filterOnA(
      columnA === Add(columnB, Literal(2)) &&
        (columnA === Add(columnB, Literal(3)) || columnB === Literal(9))).analyze

    // Nothing to propagate: the plan must come back unchanged.
    assertOptimizedTo(input, input)
  }

  test("equality operator not immediate child of root `AND` should not be used for propagation") {
    val input = filterOnA(
      columnA === Literal(0) &&
        ((columnB === columnA) === (columnB === Literal(0)))).analyze

    val expected = filterOnA(
      columnA === Literal(0) &&
        ((columnB === Literal(0)) === (columnB === Literal(0)))).analyze

    assertOptimizedTo(input, expected)
  }

  test("conflicting equality predicates") {
    val input = filterOnA(
      columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))

    // The expected plan is compared as built, without running the analyzer on it.
    val expected = filterOnA(
      columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5))

    assertOptimizedTo(input.analyze, expected)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
checkDataFilters(Set.empty)

// Only one file should be read.
checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions =>
checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 2")) { partitions =>
assert(partitions.size == 1, "when checking partitions")
assert(partitions.head.files.size == 1, "when checking files in partition 1")
assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
Expand All @@ -217,7 +217,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
checkDataFilters(Set.empty)

// Only one file should be read.
checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions =>
checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 2")) { partitions =>
assert(partitions.size == 1, "when checking partitions")
assert(partitions.head.files.size == 1, "when checking files in partition 1")
assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
Expand All @@ -235,13 +235,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
"p1=1/file1" -> 10,
"p1=2/file2" -> 10))

val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1")
val df1 = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1")
// Filter on data only are advisory so we have to reevaluate.
assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1"))
// Need to evalaute filters that are not pushed down.
assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2"))
assert(getPhysicalFilters(df1) contains resolve(df1, "c1 = 1"))
// Don't reevaluate partition only filters.
assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1")))
assert(!(getPhysicalFilters(df1) contains resolve(df1, "p1 = 1")))

val df2 = table.where("(p1 + c2) = 2 AND c1 = 1")
// Filter on data only are advisory so we have to reevaluate.
assert(getPhysicalFilters(df2) contains resolve(df2, "c1 = 1"))
// Need to evaluate filters that are not pushed down.
assert(getPhysicalFilters(df2) contains resolve(df2, "(p1 + c2) = 2"))
}

test("bucketed table") {
Expand Down

0 comments on commit f9b59ab

Please sign in to comment.