Skip to content

Commit

Permalink
[SPARK-23564][SQL] Add isNotNull check for left anti and outer joins
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Mar 2, 2018
1 parent 119f6a0 commit 45fbb85
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,32 @@ trait Predicate extends Expression {
override def dataType: DataType = BooleanType
}

trait NotNullConstraintHelper {
  /**
   * Derives attribute-level IsNotNull constraints from a single constraint by looking at the
   * attributes reachable through its null intolerant sub-expressions.
   */
  protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = {
    // An explicit IsNotNull wrapper is peeled off so that its child is scanned. Any other
    // constraint is scanned as-is: a constraint evaluates to true for every input, hence it is
    // never null and IsNotNull can likewise be pushed through its null intolerant children.
    val scanned = constraint match {
      case IsNotNull(child) => child
      case other => other
    }
    scanNullIntolerantAttribute(scanned).map(attribute => IsNotNull(attribute))
  }

  /**
   * Collects every attribute that can be reached from `expr` by descending only through
   * null intolerant expressions. The descent stops at any other kind of expression.
   */
  protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = {
    expr match {
      case attribute: Attribute =>
        Seq(attribute)
      case _: NullIntolerant =>
        expr.children.flatMap(child => scanNullIntolerantAttribute(child))
      case _ =>
        Seq.empty[Attribute]
    }
  }
}

trait PredicateHelper {
protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,8 @@ object CollapseWindow extends Rule[LogicalPlan] {
* Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
* LeftSemi joins.
*/
object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper
with NotNullConstraintHelper {

def apply(plan: LogicalPlan): LogicalPlan = {
if (SQLConf.get.constraintPropagationEnabled) {
Expand All @@ -663,7 +664,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
// right child
val constraints = join.allConstraints.filter { c =>
c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)
}
} ++ extraJoinConstraints(join).toSet
// Remove those constraints that are already enforced by either the left or the right child
val additionalConstraints = constraints -- (left.constraints ++ right.constraints)
val newConditionOpt = conditionOpt match {
Expand All @@ -675,6 +676,22 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
}
if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join
}

/**
 * Returns additional constraints which are not enforced on the result of join operations, but
 * which can be enforced either on the left or the right side.
 *
 * Only joins that carry a condition yield anything:
 *  - for LeftAnti and LeftOuter joins, the IsNotNull constraints inferred from the condition
 *    that reference only attributes of the right child are returned;
 *  - for RightOuter joins, those that reference only attributes of the left child are returned.
 *
 * @param join the join operator to inspect
 * @return the inferred IsNotNull constraints, or an empty sequence for any other join type or
 *         when the join has no condition
 */
def extraJoinConstraints(join: Join): Seq[Expression] = {
  // Infers IsNotNull constraints from every conjunct of `condition` and keeps only the ones
  // whose references are all produced by `side`.
  def notNullConstraintsOn(condition: Expression, side: LogicalPlan): Seq[Expression] =
    splitConjunctivePredicates(condition)
      .flatMap(inferIsNotNullConstraints)
      .filter(_.references.subsetOf(side.outputSet))

  // Destructure the optional condition in the pattern instead of guarding with `isDefined`
  // and calling `get`.
  join match {
    case Join(_, right, LeftAnti | LeftOuter, Some(condition)) =>
      notNullConstraintsOn(condition, right)
    case Join(left, _, RightOuter, Some(condition)) =>
      notNullConstraintsOn(condition, left)
    case _ =>
      Seq.empty[Expression]
  }
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._


trait QueryPlanConstraints { self: LogicalPlan =>
trait QueryPlanConstraints extends NotNullConstraintHelper { self: LogicalPlan =>

/**
* An [[ExpressionSet]] that contains an additional set of constraints, such as equality
Expand Down Expand Up @@ -72,31 +72,6 @@ trait QueryPlanConstraints { self: LogicalPlan =>
isNotNullConstraints -- constraints
}

/**
 * Produces the attribute-specific IsNotNull constraints implied by `constraint`, based on the
 * attributes found under its null intolerant child expressions.
 */
private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = {
  // If the root already is IsNotNull, the scan starts from its child. Otherwise the constraint
  // itself is scanned: constraints always evaluate to true, so they never return null and
  // IsNotNull can be pushed through their null intolerant children as well.
  val root = constraint match {
    case IsNotNull(inner) => inner
    case _ => constraint
  }
  scanNullIntolerantAttribute(root).map(attr => IsNotNull(attr))
}

/**
 * Walks `expr` recursively, descending only through null intolerant expressions, and gathers
 * all attributes encountered along the way.
 */
private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = {
  expr match {
    // An attribute is a leaf of the scan and is returned on its own.
    case attribute: Attribute =>
      Seq(attribute)
    // Null intolerant expressions are transparent: keep scanning each child.
    case _: NullIntolerant =>
      expr.children.flatMap(child => scanNullIntolerantAttribute(child))
    // Anything else ends the scan and contributes no attributes.
    case _ =>
      Seq.empty[Attribute]
  }
}

/**
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("SPARK-23564: left anti join should filter out null join keys on right side") {
  // Two aliases of the same test relation, joined on column `a`.
  val leftSide = testRelation.subquery('x)
  val rightSide = testRelation.subquery('y)
  val joinCondition = Some("x.a".attr === "y.a".attr)

  val inputPlan = leftSide.join(rightSide, LeftAnti, joinCondition).analyze
  // Expected plan: the left side is untouched, the right side is filtered with IsNotNull on `a`.
  val expectedPlan =
    leftSide.join(rightSide.where(IsNotNull('a)), LeftAnti, joinCondition).analyze

  comparePlans(Optimize.execute(inputPlan), expectedPlan)
}

test("SPARK-23564: left outer join should filter out null join keys on right side") {
  // Two aliases of the same test relation, joined on column `a`.
  val leftSide = testRelation.subquery('x)
  val rightSide = testRelation.subquery('y)
  val joinCondition = Some("x.a".attr === "y.a".attr)

  val inputPlan = leftSide.join(rightSide, LeftOuter, joinCondition).analyze
  // Expected plan: the left side is untouched, the right side is filtered with IsNotNull on `a`.
  val expectedPlan =
    leftSide.join(rightSide.where(IsNotNull('a)), LeftOuter, joinCondition).analyze

  comparePlans(Optimize.execute(inputPlan), expectedPlan)
}

test("SPARK-23564: right outer join should filter out null join keys on left side") {
  // Two aliases of the same test relation, joined on column `a`.
  val leftSide = testRelation.subquery('x)
  val rightSide = testRelation.subquery('y)
  val joinCondition = Some("x.a".attr === "y.a".attr)

  val inputPlan = leftSide.join(rightSide, RightOuter, joinCondition).analyze
  // Expected plan: the right side is untouched, the left side is filtered with IsNotNull on `a`.
  val expectedPlan =
    leftSide.where(IsNotNull('a)).join(rightSide, RightOuter, joinCondition).analyze

  comparePlans(Optimize.execute(inputPlan), expectedPlan)
}
}

0 comments on commit 45fbb85

Please sign in to comment.