From dc2112f6a26125cd6a67eac79cef91751ac639f8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 8 Aug 2017 20:08:38 +0900 Subject: [PATCH] Should not infer the constraints that are trivially true --- .../sql/catalyst/optimizer/expressions.scala | 19 +------- .../plans/logical/QueryPlanConstraints.scala | 25 +++++++++- .../sql/catalyst/util/OptimizerUtils.scala | 46 +++++++++++++++++++ .../InferFiltersFromConstraintsSuite.scala | 17 +++++++ 4 files changed, 89 insertions(+), 18 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/OptimizerUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 79a6c8663a56b..21a5b89d9fec8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.util.OptimizerUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -72,26 +73,10 @@ object ConstantFolding extends Rule[LogicalPlan] { * in the AND node. 
*/ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { - private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find { - case _: Not | _: Or => true - case _ => false - }.isDefined - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f: Filter => f transformExpressionsUp { case and: And => - val conjunctivePredicates = - splitConjunctivePredicates(and) - .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe]) - .filterNot(expr => containsNonConjunctionPredicates(expr)) - - val equalityPredicates = conjunctivePredicates.collect { - case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e) - case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e) - case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e) - case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e) - } - + val equalityPredicates = getLiteralEqualityPredicates(splitConjunctivePredicates(and)) val constantsMap = AttributeMap(equalityPredicates.map(_._1)) val predicates = equalityPredicates.map(_._2).toSet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 8bffbd0c208cb..3a02fb3f73ca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.OptimizerUtils._ trait QueryPlanConstraints { self: LogicalPlan => @@ -127,7 +128,29 @@ trait QueryPlanConstraints { self: LogicalPlan => }) case _ => // No inference } - 
inferredConstraints -- constraints + + val allConstraints = inferredConstraints ++ constraints + val additionalConstraints = inferredConstraints -- constraints + + // Filters out meaningless constraints, e.g., given constraint `a = 1`, `b = 1`, `a = c`, and + // `b = c`, we first infer `a = b`. This constraint is trivially true, so we drop it here. + // See SPARK-21652 for details. + val equalityPredicates = + AttributeMap(getLiteralEqualityPredicates(allConstraints.toSeq).map(_._1)) + additionalConstraints.filterNot { + case b: BinaryComparison => + (b.left, b.right) match { + case (l: Attribute, r: Attribute) => + (equalityPredicates.get(l), equalityPredicates.get(r)) match { + case (Some(leftLiteral), Some(rightLiteral)) => + b.withNewChildren(leftLiteral:: rightLiteral :: Nil).eval(EmptyRow) + .asInstanceOf[Boolean] + case _ => false + } + case _ => false + } + case _ => false + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/OptimizerUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/OptimizerUtils.scala new file mode 100644 index 0000000000000..7478b94f6526d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/OptimizerUtils.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.expressions._ + + +/** + * Common utility methods used by the optimizer. + */ +object OptimizerUtils { + + private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find { + case _: Not | _: Or => true + case _ => false + }.isDefined + + def getLiteralEqualityPredicates(conjunctivePredicates: Seq[Expression]) + : Seq[((AttributeReference, Literal), BinaryComparison)] = { + val conjunctiveEqualPredicates = + conjunctivePredicates + .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe]) + .filterNot(expr => containsNonConjunctionPredicates(expr)) + conjunctiveEqualPredicates.collect { + case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e) + case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e) + case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e) + case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index d2dd469e2d74f..fa0c213d7f137 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -212,4 +212,21 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("SPARK-21652 Should not infer the constraints that are trivially true") { + val r1 = LocalRelation('r1a.int, 'r1b.int) + val r2 = LocalRelation('r2a.int) + val 
originalQuery = r1 + .where('r1a === 1 && 'r1b === 1) + .join(r2, Inner, Some('r1a === 'r2a && 'r1b === 'r2a)).analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = r1 + .where(IsNotNull('r1a) && IsNotNull('r1b) && 'r1a === 1 && 'r1b === 1) + .join( + r2.where(IsNotNull('r2a) && 'r2a === 1), + Inner, + Some('r1a === 'r2a && 'r1b === 'r2a) + ).analyze + comparePlans(optimized, correctAnswer) + } }