From f9b59abeae16088c7c4d3a475762ef6c4ad42b4b Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Mon, 29 May 2017 12:21:34 +0200 Subject: [PATCH] [SPARK-20758][SQL] Add Constant propagation optimization ## What changes were proposed in this pull request? See class doc of `ConstantPropagation` for the approach used. ## How was this patch tested? - Added unit tests Author: Tejas Patil Closes #17993 from tejasapatil/SPARK-20758_const_propagation. --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/expressions.scala | 56 ++++++ .../optimizer/ConstantPropagationSuite.scala | 167 ++++++++++++++++++ .../datasources/FileSourceStrategySuite.scala | 18 +- 4 files changed, 235 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ae2f6bfa94ae7..d16689a34298a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -92,6 +92,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CombineUnions, // Constant folding and strength reduction NullPropagation(conf), + ConstantPropagation, FoldablePropagation, OptimizeIn(conf), ConstantFolding, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8931eb2c8f3b1..51f749a8bf857 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -54,6 +54,62 @@ object ConstantFolding extends Rule[LogicalPlan] { } } +/** + * 
Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding + * value in conjunctive [[Expression Expressions]] + * e.g. + * {{{ + * SELECT * FROM table WHERE i = 5 AND j = i + 3 + * ==> SELECT * FROM table WHERE i = 5 AND j = 8 + * }}} + * + * Approach used: + * - Start from AND operator as the root + * - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they + * don't have a `NOT` or `OR` operator in them + * - Populate a mapping of attribute => constant value by looking at all the equals predicates + * - Using this mapping, replace occurrence of the attributes with the corresponding constant values + * in the AND node. + */ +object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { + private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find { + case _: Not | _: Or => true + case _ => false + }.isDefined + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f: Filter => f transformExpressionsUp { + case and: And => + val conjunctivePredicates = + splitConjunctivePredicates(and) + .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe]) + .filterNot(expr => containsNonConjunctionPredicates(expr)) + + val equalityPredicates = conjunctivePredicates.collect { + case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e) + case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e) + case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e) + case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e) + } + + val constantsMap = AttributeMap(equalityPredicates.map(_._1)) + val predicates = equalityPredicates.map(_._2).toSet + + def replaceConstants(expression: Expression) = expression transform { + case a: AttributeReference => + constantsMap.get(a) match { + case Some(literal) => literal + case None => a 
+ } + } + + and transform { + case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e) + case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e) + } + } + } +} /** * Reorder associative integral-type operators and fold all constants into one. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala new file mode 100644 index 0000000000000..81d2f3667e2d0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +/** + * Unit tests for constant propagation in expressions. + */ +class ConstantPropagationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("ConstantPropagation", FixedPoint(10), + ColumnPruning, + ConstantPropagation, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + private val columnA = 'a.int + private val columnB = 'b.int + private val columnC = 'c.int + + test("basic test") { + val query = testRelation + .select(columnA) + .where(columnA === Add(columnB, Literal(1)) && columnB === Literal(10)) + + val correctAnswer = + testRelation + .select(columnA) + .where(columnA === Literal(11) && columnB === Literal(10)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("with combination of AND and OR predicates") { + val query = testRelation + .select(columnA) + .where( + columnA === Add(columnB, Literal(1)) && + columnB === Literal(10) && + (columnA === Add(columnC, Literal(3)) || columnB === columnC)) + .analyze + + val correctAnswer = + testRelation + .select(columnA) + .where( + columnA === Literal(11) && + columnB === Literal(10) && + (Literal(11) === Add(columnC, Literal(3)) || Literal(10) === columnC)) + .analyze + + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("equality predicates outside a `NOT` can be propagated within a `NOT`") { + val query = 
testRelation + .select(columnA) + .where(Not(columnA === Add(columnB, Literal(1))) && columnB === Literal(10)) + .analyze + + val correctAnswer = + testRelation + .select(columnA) + .where(Not(columnA === Literal(11)) && columnB === Literal(10)) + .analyze + + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("equality predicates inside a `NOT` should not be picked for propagation") { + val query = testRelation + .select(columnA) + .where(Not(columnB === Literal(10)) && columnA === Add(columnB, Literal(1))) + .analyze + + comparePlans(Optimize.execute(query), query) + } + + test("equality predicates outside a `OR` can be propagated within a `OR`") { + val query = testRelation + .select(columnA) + .where( + columnA === Literal(2) && + (columnA === Add(columnB, Literal(3)) || columnB === Literal(9))) + .analyze + + val correctAnswer = testRelation + .select(columnA) + .where( + columnA === Literal(2) && + (Literal(2) === Add(columnB, Literal(3)) || columnB === Literal(9))) + .analyze + + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("equality predicates inside a `OR` should not be picked for propagation") { + val query = testRelation + .select(columnA) + .where( + columnA === Add(columnB, Literal(2)) && + (columnA === Add(columnB, Literal(3)) || columnB === Literal(9))) + .analyze + + comparePlans(Optimize.execute(query), query) + } + + test("equality operator not immediate child of root `AND` should not be used for propagation") { + val query = testRelation + .select(columnA) + .where( + columnA === Literal(0) && + ((columnB === columnA) === (columnB === Literal(0)))) + .analyze + + val correctAnswer = testRelation + .select(columnA) + .where( + columnA === Literal(0) && + ((columnB === Literal(0)) === (columnB === Literal(0)))) + .analyze + + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("conflicting equality predicates") { + val query = testRelation + .select(columnA) + .where( + columnA === Literal(1) && 
columnA === Literal(2) && columnB === Add(columnA, Literal(3))) + + val correctAnswer = testRelation + .select(columnA) + .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)) + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index fa3c69612704d..9a2dcafb5e4b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -190,7 +190,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi checkDataFilters(Set.empty) // Only one file should be read. - checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions => + checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 2")) { partitions => assert(partitions.size == 1, "when checking partitions") assert(partitions.head.files.size == 1, "when checking files in partition 1") assert(partitions.head.files.head.partitionValues.getInt(0) == 1, @@ -217,7 +217,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi checkDataFilters(Set.empty) // Only one file should be read. 
- checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions => + checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 2")) { partitions => assert(partitions.size == 1, "when checking partitions") assert(partitions.head.files.size == 1, "when checking files in partition 1") assert(partitions.head.files.head.partitionValues.getInt(0) == 1, @@ -235,13 +235,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi "p1=1/file1" -> 10, "p1=2/file2" -> 10)) - val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") + val df1 = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") // Filter on data only are advisory so we have to reevaluate. - assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1")) - // Need to evalaute filters that are not pushed down. - assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2")) + assert(getPhysicalFilters(df1) contains resolve(df1, "c1 = 1")) // Don't reevaluate partition only filters. - assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1"))) + assert(!(getPhysicalFilters(df1) contains resolve(df1, "p1 = 1"))) + + val df2 = table.where("(p1 + c2) = 2 AND c1 = 1") + // Filter on data only are advisory so we have to reevaluate. + assert(getPhysicalFilters(df2) contains resolve(df2, "c1 = 1")) + // Need to evaluate filters that are not pushed down. + assert(getPhysicalFilters(df2) contains resolve(df2, "(p1 + c2) = 2")) } test("bucketed table") {