[SPARK-23564][SQL] Add isNotNull check for left anti and outer joins #20717

Closed
wants to merge 4 commits into from
@@ -45,6 +45,32 @@ trait Predicate extends Expression {
override def dataType: DataType = BooleanType
}

trait NotNullConstraintHelper {
/**
* Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
* of constraints.
*/
protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
constraint match {
// When the root is IsNotNull, we can push IsNotNull through the child null intolerant
// expressions
case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_))
// Constraints always return true for all the inputs. That means, null will never be returned.
// Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child
// null intolerant expressions.
case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_))
}

/**
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}
}

trait PredicateHelper {
protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
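
To make the inference rule in `NotNullConstraintHelper` concrete, here is a minimal, self-contained sketch on a toy expression tree. The `Expr` hierarchy and the `NullIntolerantSketch` object are hypothetical stand-ins, not Catalyst classes. Only the two method bodies follow the shape of the code in the hunk above.

```scala
object NullIntolerantSketch {
  // Toy stand-ins for Catalyst expressions (hypothetical, not Spark's API).
  sealed trait Expr { def children: Seq[Expr] }
  // Marker: an expression that returns null whenever any input is null.
  trait NullIntolerant

  case class Attr(name: String) extends Expr { val children: Seq[Expr] = Nil }
  case class Lit(value: Int) extends Expr { val children: Seq[Expr] = Nil }
  case class Add(l: Expr, r: Expr) extends Expr with NullIntolerant {
    val children: Seq[Expr] = Seq(l, r)
  }
  case class EqualTo(l: Expr, r: Expr) extends Expr with NullIntolerant {
    val children: Seq[Expr] = Seq(l, r)
  }
  // Coalesce tolerates null inputs, so it is not marked NullIntolerant.
  case class Coalesce(children: Seq[Expr]) extends Expr
  case class IsNotNull(child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }

  // Collect attributes reachable through null-intolerant expressions only.
  def scanNullIntolerantAttribute(expr: Expr): Seq[Attr] = expr match {
    case a: Attr => Seq(a)
    case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
    case _ => Seq.empty[Attr]
  }

  // A constraint holds (is true) for every row, so none of the attributes
  // found above can be null in those rows.
  def inferIsNotNullConstraints(constraint: Expr): Seq[Expr] = constraint match {
    case IsNotNull(e) => scanNullIntolerantAttribute(e).map(IsNotNull(_))
    case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_))
  }
}
```

In this toy model, `a + 1 = b` yields `IsNotNull(a)` and `IsNotNull(b)`, while `coalesce(a, b) = c` yields only `IsNotNull(c)`, because the scan stops at the null-tolerant `Coalesce`.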
@@ -638,7 +638,8 @@ object CollapseWindow extends Rule[LogicalPlan] {
* Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
* LeftSemi joins.
*/
- object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
+ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper
+ with NotNullConstraintHelper {

def apply(plan: LogicalPlan): LogicalPlan = {
if (SQLConf.get.constraintPropagationEnabled) {
@@ -663,7 +664,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
// right child
val constraints = join.allConstraints.filter { c =>
c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)
- }
+ } ++ extraJoinConstraints(join).toSet
// Remove those constraints that are already enforced by either the left or the right child
val additionalConstraints = constraints -- (left.constraints ++ right.constraints)
val newConditionOpt = conditionOpt match {
@@ -675,6 +676,22 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
}
if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join
}

/**
* Returns additional constraints which are not enforced on the result of join operations, but
* which can be enforced either on the left or the right side
Contributor

Why not put this in Join.validConstraints? LogicalPlan.constraints should only contain constraints for the plan output, but LogicalPlan.allConstraints can contain more.

Contributor Author

I didn't put it there because constraints is derived from allConstraints, so adding these to validConstraints could have made them part of constraints too.

Contributor

Ah, I see the problem. For a left-anti join, although Join.output reuses the attributes from the left child's output, they are actually different attributes, e.g. the Join may output null values, so we can't generate these constraints in Join.validConstraints.

I think we can override both allConstraints and constraints, to make sure these extra constraints appear in allConstraints, but not constraints.
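
The point discussed in this thread can be illustrated with a small, self-contained model. This is only a sketch under simplifying assumptions: each row is reduced to its join key, `None` stands in for SQL NULL, and `JoinNullKeySketch` and its methods are hypothetical names, not Spark code. It shows that dropping null keys from the right child leaves the result of a left anti or left outer join unchanged, while the join output itself can still contain nulls, so the `IsNotNull` constraint holds for the child but not for the join's output.

```scala
object JoinNullKeySketch {
  // A row is reduced to its join key; None stands in for SQL NULL.
  type Key = Option[Int]

  // Equi-join matching: a NULL key on either side never matches.
  def keysMatch(l: Key, r: Key): Boolean = l.isDefined && l == r

  // Left anti join: keep the left rows that have no match on the right.
  def leftAnti(left: Seq[Key], right: Seq[Key]): Seq[Key] =
    left.filterNot(l => right.exists(r => keysMatch(l, r)))

  // Left outer join: each output row is (left key, right key);
  // the right key is None when the row is null-padded.
  def leftOuter(left: Seq[Key], right: Seq[Key]): Seq[(Key, Key)] =
    left.flatMap { l =>
      val matches = right.filter(r => keysMatch(l, r))
      if (matches.isEmpty) Seq((l, Option.empty[Int])) else matches.map(r => (l, r))
    }

  val left: Seq[Key] = Seq(Some(1), Some(2), None)
  val right: Seq[Key] = Seq(Some(1), None, Some(3))
  // The right child after the inferred IS NOT NULL filter is applied.
  val rightNotNull: Seq[Key] = right.filter(_.isDefined)
}
```

With these inputs, `leftAnti(left, right)` and `leftAnti(left, rightNotNull)` return the same rows, and likewise for `leftOuter`. The left outer result still carries a null right key for unmatched rows, which is why such a constraint cannot be advertised on the join output.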

*/
def extraJoinConstraints(join: Join): Seq[Expression] = {
join match {
case Join(_, right, LeftAnti | LeftOuter, condition) if condition.isDefined =>
splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter(
_.references.subsetOf(right.outputSet))
case Join(left, _, RightOuter, condition) if condition.isDefined =>
splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter(
_.references.subsetOf(left.outputSet))
case _ => Seq.empty[Expression]
}
}
}

/**
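
As a rough summary of which child receives the extra filter in `extraJoinConstraints`, the following sketch reduces the rule to plain column names. Everything here is a hypothetical simplification: `ExtraJoinConstraintsSketch`, `EqualKeys`, and `extraNotNullColumns` are not part of Spark, and the condition is assumed to consist only of equi-join conjuncts of the form `leftCol = rightCol`.

```scala
object ExtraJoinConstraintsSketch {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftAnti extends JoinType
  case object LeftOuter extends JoinType
  case object RightOuter extends JoinType

  // One equi-join conjunct `leftCol = rightCol`, with columns as plain strings.
  final case class EqualKeys(leftCol: String, rightCol: String)

  // Columns of one child that can additionally be filtered with IS NOT NULL.
  // Only the side whose unmatched rows are dropped is eligible.
  def extraNotNullColumns(joinType: JoinType, condition: Seq[EqualKeys]): Seq[String] =
    joinType match {
      case LeftAnti | LeftOuter => condition.map(_.rightCol) // filter the right child
      case RightOuter => condition.map(_.leftCol)            // filter the left child
      case _ => Seq.empty[String]                            // no extra constraints
    }
}
```

For the condition `x.a = y.a`, this returns `y.a` for `LeftAnti` and `LeftOuter` and `x.a` for `RightOuter`, matching the expected plans in the tests added by this PR.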
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._


- trait QueryPlanConstraints { self: LogicalPlan =>
+ trait QueryPlanConstraints extends NotNullConstraintHelper { self: LogicalPlan =>

/**
* An [[ExpressionSet]] that contains an additional set of constraints, such as equality
@@ -72,31 +72,6 @@ trait QueryPlanConstraints { self: LogicalPlan =>
isNotNullConstraints -- constraints
}

/**
* Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
* of constraints.
*/
private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
constraint match {
// When the root is IsNotNull, we can push IsNotNull through the child null intolerant
// expressions
case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_))
// Constraints always return true for all the inputs. That means, null will never be returned.
// Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child
// null intolerant expressions.
case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_))
}

/**
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}

/**
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
@@ -204,4 +204,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("SPARK-23564: left anti join should filter out null join keys on right side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftAnti, condition).analyze
val left = x
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftAnti, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("SPARK-23564: left outer join should filter out null join keys on right side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftOuter, condition).analyze
val left = x
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftOuter, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("SPARK-23564: right outer join should filter out null join keys on left side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, RightOuter, condition).analyze
val left = x.where(IsNotNull('a))
val right = y
val correctAnswer = left.join(right, RightOuter, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
Member

dongjoon-hyun Mar 22, 2018

Since these are simple repetitions of the previous test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides"), how about adding a helper test function and simplifying them together now? Something like the following?

  private def testConstraints(
      x: LogicalPlan, y: LogicalPlan, left: LogicalPlan, right: LogicalPlan, joinType: JoinType) = {
    val condition = Some("x.a".attr === "y.a".attr)
    val originalQuery = x.join(y, joinType, condition).analyze
    val correctAnswer = left.join(right, joinType, condition).analyze
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, correctAnswer)
  }

  test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
    val x = testRelation.subquery('x)
    val y = testRelation.subquery('y)
    testConstraints(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
  }

  test("SPARK-23564: left anti join should filter out null join keys on right side") {
    val x = testRelation.subquery('x)
    val y = testRelation.subquery('y)
    testConstraints(x, y, x, y.where(IsNotNull('a)), LeftAnti)
  }

  test("SPARK-23564: left outer join should filter out null join keys on right side") {
    val x = testRelation.subquery('x)
    val y = testRelation.subquery('y)
    testConstraints(x, y, x, y.where(IsNotNull('a)), LeftOuter)
  }

  test("SPARK-23564: right outer join should filter out null join keys on left side") {
    val x = testRelation.subquery('x)
    val y = testRelation.subquery('y)
    testConstraints(x, y, x.where(IsNotNull('a)), y, RightOuter)
  }

}