Should not infer the constraints that are trivially true
maropu committed Aug 8, 2017
1 parent d4e7f20 commit dc2112f
Showing 4 changed files with 89 additions and 18 deletions.
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.util.OptimizerUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

@@ -72,26 +73,10 @@ object ConstantFolding extends Rule[LogicalPlan] {
  * in the AND node.
  */
 object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
-  private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find {
-    case _: Not | _: Or => true
-    case _ => false
-  }.isDefined
-
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case f: Filter => f transformExpressionsUp {
       case and: And =>
-        val conjunctivePredicates =
-          splitConjunctivePredicates(and)
-            .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe])
-            .filterNot(expr => containsNonConjunctionPredicates(expr))
-
-        val equalityPredicates = conjunctivePredicates.collect {
-          case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e)
-          case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e)
-          case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e)
-          case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e)
-        }
-
+        val equalityPredicates = getLiteralEqualityPredicates(splitConjunctivePredicates(and))
         val constantsMap = AttributeMap(equalityPredicates.map(_._1))
         val predicates = equalityPredicates.map(_._2).toSet
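For reference, a rough sketch (not part of this commit) of what the rewritten rule computes for a filter such as `a = 1 AND b = 2 AND a > c`. It assumes a Spark REPL with the Catalyst expression DSL and the new helper in scope; the attribute names are hypothetical:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.OptimizerUtils._

val a = 'a.int
val b = 'b.int
val c = 'c.int
// The conjuncts of `a = 1 AND b = 2 AND a > c`, already split on AND.
val conjuncts = Seq(a === 1, b === 2, a > c)
// Each literal equality contributes an (attribute, literal) binding plus the predicate itself.
val equalityPredicates = getLiteralEqualityPredicates(conjuncts)
// constantsMap binds a -> 1 and b -> 2; ConstantPropagation then rewrites `a > c` to `1 > c`.
val constantsMap = AttributeMap(equalityPredicates.map(_._1))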

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.OptimizerUtils._


 trait QueryPlanConstraints { self: LogicalPlan =>
@@ -127,7 +128,29 @@ trait QueryPlanConstraints { self: LogicalPlan =>
         })
       case _ => // No inference
     }
-    inferredConstraints -- constraints
+
+    val allConstraints = inferredConstraints ++ constraints
+    val additionalConstraints = inferredConstraints -- constraints
+
+    // Filters out meaningless constraints, e.g., given constraints `a = 1`, `b = 1`, `a = c`, and
+    // `b = c`, we first infer `a = b`. This constraint is trivially true, so we drop it here.
+    // See SPARK-21652 for details.
+    val equalityPredicates =
+      AttributeMap(getLiteralEqualityPredicates(allConstraints.toSeq).map(_._1))
+    additionalConstraints.filterNot {
+      case b: BinaryComparison =>
+        (b.left, b.right) match {
+          case (l: Attribute, r: Attribute) =>
+            (equalityPredicates.get(l), equalityPredicates.get(r)) match {
+              case (Some(leftLiteral), Some(rightLiteral)) =>
+                b.withNewChildren(leftLiteral :: rightLiteral :: Nil).eval(EmptyRow)
+                  .asInstanceOf[Boolean]
+              case _ => false
+            }
+          case _ => false
+        }
+      case _ => false
+    }
   }

/**
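The check added in the hunk above boils down to substituting the literals bound to both attributes and evaluating the resulting comparison against an empty row. A minimal sketch of that step (illustration only, using the same `catalyst.expressions` import that is already present in this file):

import org.apache.spark.sql.catalyst.expressions._

// With constraints `a = 1` and `b = 1`, the inferred `a = b` becomes `1 = 1` after
// substituting the bound literals; evaluating it against an empty row yields true,
// so filterNot drops it as trivially true.
val substituted = EqualTo(Literal(1), Literal(1))
val isTriviallyTrue = substituted.eval(EmptyRow).asInstanceOf[Boolean]  // true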
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions._


/**
 * Common utility methods used by Optimizer rules.
 */
object OptimizerUtils {

  // Returns true if the expression contains a Not or Or anywhere in its tree.
  private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find {
    case _: Not | _: Or => true
    case _ => false
  }.isDefined

  /**
   * Collects `attribute = literal` (and null-safe `<=>`) predicates from the given conjuncts
   * and returns each (attribute, literal) binding together with its originating predicate.
   */
  def getLiteralEqualityPredicates(conjunctivePredicates: Seq[Expression])
    : Seq[((AttributeReference, Literal), BinaryComparison)] = {
    val conjunctiveEqualPredicates =
      conjunctivePredicates
        .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe])
        .filterNot(expr => containsNonConjunctionPredicates(expr))
    conjunctiveEqualPredicates.collect {
      case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e)
      case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e)
      case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e)
      case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e)
    }
  }
}
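A rough usage sketch of the new helper (hypothetical attributes, Catalyst expression DSL assumed; not code from this commit):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.OptimizerUtils._

val x = 'x.int
val y = 'y.int
val conjuncts = Seq(
  x === 1,            // collected: attribute = literal
  Literal(2) <=> y,   // collected: literal <=> attribute (null-safe equality)
  x === y)            // ignored: no literal on either side
// bindings: Seq((x, 1), (y, 2)); the second tuple element is the source predicate.
val bindings = getLiteralEqualityPredicates(conjuncts).map(_._1)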
@@ -212,4 +212,21 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
       comparePlans(optimized, originalQuery)
     }
   }
+
+  test("SPARK-21652 Should not infer the constraints that are trivially true") {
+    val r1 = LocalRelation('r1a.int, 'r1b.int)
+    val r2 = LocalRelation('r2a.int)
+    val originalQuery = r1
+      .where('r1a === 1 && 'r1b === 1)
+      .join(r2, Inner, Some('r1a === 'r2a && 'r1b === 'r2a)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    val correctAnswer = r1
+      .where(IsNotNull('r1a) && IsNotNull('r1b) && 'r1a === 1 && 'r1b === 1)
+      .join(
+        r2.where(IsNotNull('r2a) && 'r2a === 1),
+        Inner,
+        Some('r1a === 'r2a && 'r1b === 'r2a)
+      ).analyze
+    comparePlans(optimized, correctAnswer)
+  }
 }
