Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23564][SQL] Add isNotNull check for left anti and outer joins #20717

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* An [[ExpressionSet]] that contains an additional set of constraints, such as equality
* constraints and `isNotNull` constraints, etc.
*/
lazy val allConstraints: ExpressionSet = {
if (conf.constraintPropagationEnabled) {
ExpressionSet(validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints)))
} else {
ExpressionSet(Set.empty)
}
}
lazy val allConstraints: ExpressionSet = ExpressionSet(constructAllConstraints)

/**
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
Expand All @@ -55,6 +47,20 @@ trait QueryPlanConstraints { self: LogicalPlan =>
*/
protected def validConstraints: Set[Expression] = Set.empty

/**
* Returns the [[Expression]]s representing all the constraints which can be enforced on the
* current operator.
*/
protected def constructAllConstraints: Set[Expression] = {
if (conf.constraintPropagationEnabled) {
validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints))
} else {
Set.empty
}
}

/**
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
* non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
Expand All @@ -76,7 +82,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
* of constraints.
*/
private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
constraint match {
// When the root is IsNotNull, we can push IsNotNull through the child null intolerant
// expressions
Expand All @@ -91,7 +97,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
Copy link
Member

@dongjoon-hyun dongjoon-hyun Mar 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let keep this private because this is used only in this class.

case a: Attribute => Seq(a)
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,26 @@ case class Join(
case UsingJoin(_, _) => false
case _ => resolvedExceptNatural
}

override protected def constructAllConstraints: Set[Expression] = {
// additional constraints which are not enforced on the result of join operations, but can be
// enforced either on the left or the right side
val additionalConstraints = joinType match {
case LeftAnti | LeftOuter if condition.isDefined =>
splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter(
_.references.subsetOf(right.outputSet))
case RightOuter if condition.isDefined =>
splitConjunctivePredicates(condition.get).flatMap(inferIsNotNullConstraints).filter(
_.references.subsetOf(left.outputSet))
case _ => Seq.empty[Expression]
}
super.constructAllConstraints ++ additionalConstraints
}

override lazy val constraints: ExpressionSet = ExpressionSet(
super.constructAllConstraints.filter { c =>
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
})
Copy link
Member

@dongjoon-hyun dongjoon-hyun Mar 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add more test cases (or statements) for this code path?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I added some statements to the ConstraintPropagationSuite.

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("SPARK-23564: left anti join should filter out null join keys on right side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftAnti, condition).analyze
val left = x
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftAnti, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("SPARK-23564: left outer join should filter out null join keys on right side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftOuter, condition).analyze
val left = x
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftOuter, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("SPARK-23564: right outer join should filter out null join keys on left side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, RightOuter, condition).analyze
val left = x.where(IsNotNull('a))
val right = y
val correctAnswer = left.join(right, RightOuter, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
Copy link
Member

@dongjoon-hyun dongjoon-hyun Mar 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a simple repetition of the previous test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides", what about making helper test function and simplify these together at this time? Something like the following?

  private def testConstraints(
      x: LogicalPlan, y: LogicalPlan, left: LogicalPlan, right: LogicalPlan, joinType: JoinType) = {
    val condition = Some("x.a".attr === "y.a".attr)
    val originalQuery = x.join(y, joinType, condition).analyze
    val correctAnswer = left.join(right, joinType, condition).analyze
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, correctAnswer)
  }

  test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
    val x = testRelation.subquery('x)
    val y = testRelation.subquery('y)
    testConstraints(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
  }

  test("SPARK-23564: left anti join should filter out null join keys on right side") {
    val x = testRelation.subquery('x)
    val y = testRelation.subquery('y)
    testConstraints(x, y, x, y.where(IsNotNull('a)), LeftAnti)
  }

  test("SPARK-23564: left outer join should filter out null join keys on right side") {
    val x = testRelation.subquery('x)
    val y = testRelation.subquery('y)
    testConstraints(x, y, x, y.where(IsNotNull('a)), LeftOuter)
  }

  test("SPARK-23564: right outer join should filter out null join keys on left side") {
    val x = testRelation.subquery('x)
    val y = testRelation.subquery('y)
    testConstraints(x, y, x.where(IsNotNull('a)), y, RightOuter)
  }

}