From 45fbb851e76eeaa45c9926571059274efca2441a Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Fri, 2 Mar 2018 17:27:18 +0100
Subject: [PATCH] [SPARK-23564][SQL] Add isNotNull check for left anti and outer joins

---
 .../sql/catalyst/expressions/predicates.scala | 26 ++++++++++++++
 .../sql/catalyst/optimizer/Optimizer.scala    | 21 +++++++++--
 .../plans/logical/QueryPlanConstraints.scala  | 27 +-------------
 .../InferFiltersFromConstraintsSuite.scala    | 36 +++++++++++++++++++
 4 files changed, 82 insertions(+), 28 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index a6d41ea7d00d4..398d06122c591 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -45,6 +45,32 @@ trait Predicate extends Expression {
   override def dataType: DataType = BooleanType
 }
 
+trait NotNullConstraintHelper {
+  /**
+   * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
+   * of constraints.
+   */
+  protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
+    constraint match {
+      // When the root is IsNotNull, we can push IsNotNull through the child null intolerant
+      // expressions
+      case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_))
+      // Constraints always return true for all the inputs. That means, null will never be returned.
+      // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child
+      // null intolerant expressions.
+      case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_))
+    }
+
+  /**
+   * Recursively explores the expressions which are null intolerant and returns all attributes
+   * in these expressions.
+   */
+  protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
+    case a: Attribute => Seq(a)
+    case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
+    case _ => Seq.empty[Attribute]
+  }
+}
 
 trait PredicateHelper {
   protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 91208479be03b..72eaffb25fd12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -638,7 +638,8 @@ object CollapseWindow extends Rule[LogicalPlan] {
  * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
  * LeftSemi joins.
  */
-object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
+object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper
+  with NotNullConstraintHelper {
 
   def apply(plan: LogicalPlan): LogicalPlan = {
     if (SQLConf.get.constraintPropagationEnabled) {
@@ -663,7 +664,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
       // right child
       val constraints = join.allConstraints.filter { c =>
         c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)
-      }
+      } ++ extraJoinConstraints(join).toSet
       // Remove those constraints that are already enforced by either the left or the right child
       val additionalConstraints = constraints -- (left.constraints ++ right.constraints)
       val newConditionOpt = conditionOpt match {
@@ -675,6 +676,22 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
       }
       if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join
   }
+
+  /**
+   * Returns additional constraints which are not enforced on the result of join operations, but
+   * which can be enforced either on the left or the right side
+   */
+  def extraJoinConstraints(join: Join): Seq[Expression] = {
+    join match {
+      case Join(_, right, LeftAnti | LeftOuter, condition) if condition.isDefined =>
+        splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter(
+          _.references.subsetOf(right.outputSet))
+      case Join(left, _, RightOuter, condition) if condition.isDefined =>
+        splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter(
+          _.references.subsetOf(left.outputSet))
+      case _ => Seq.empty[Expression]
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index 046848875548b..a82dc7be509fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.expressions._
 
 
-trait QueryPlanConstraints { self: LogicalPlan =>
+trait QueryPlanConstraints extends NotNullConstraintHelper { self: LogicalPlan =>
 
   /**
    * An [[ExpressionSet]] that contains an additional set of constraints, such as equality
@@ -72,31 +72,6 @@ trait QueryPlanConstraints { self: LogicalPlan =>
     isNotNullConstraints -- constraints
   }
 
-  /**
-   * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
-   * of constraints.
-   */
-  private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
-    constraint match {
-      // When the root is IsNotNull, we can push IsNotNull through the child null intolerant
-      // expressions
-      case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_))
-      // Constraints always return true for all the inputs. That means, null will never be returned.
-      // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child
-      // null intolerant expressions.
-      case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_))
-    }
-
-  /**
-   * Recursively explores the expressions which are null intolerant and returns all attributes
-   * in these expressions.
-   */
-  private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
-    case a: Attribute => Seq(a)
-    case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
-    case _ => Seq.empty[Attribute]
-  }
-
   /**
    * Infers an additional set of constraints from a given set of equality constraints.
    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index f78c2356e35a5..4ad6ad02108c7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -204,4 +204,40 @@
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-23564: left anti join should filter out null join keys on right side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val condition = Some("x.a".attr === "y.a".attr)
+    val originalQuery = x.join(y, LeftAnti, condition).analyze
+    val left = x
+    val right = y.where(IsNotNull('a))
+    val correctAnswer = left.join(right, LeftAnti, condition).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-23564: left outer join should filter out null join keys on right side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val condition = Some("x.a".attr === "y.a".attr)
+    val originalQuery = x.join(y, LeftOuter, condition).analyze
+    val left = x
+    val right = y.where(IsNotNull('a))
+    val correctAnswer = left.join(right, LeftOuter, condition).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-23564: right outer join should filter out null join keys on left side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val condition = Some("x.a".attr === "y.a".attr)
+    val originalQuery = x.join(y, RightOuter, condition).analyze
+    val left = x.where(IsNotNull('a))
+    val right = y
+    val correctAnswer = left.join(right, RightOuter, condition).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
 }