Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23564][SQL] Add isNotNull check for left anti and outer joins #20717

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* An [[ExpressionSet]] that contains an additional set of constraints, such as equality
* constraints and `isNotNull` constraints, etc.
*/
lazy val allConstraints: ExpressionSet = {
if (conf.constraintPropagationEnabled) {
ExpressionSet(validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints)))
} else {
ExpressionSet(Set.empty)
}
}
lazy val allConstraints: ExpressionSet = ExpressionSet(constructAllConstraints)

/**
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
Expand All @@ -55,6 +47,20 @@ trait QueryPlanConstraints { self: LogicalPlan =>
*/
protected def validConstraints: Set[Expression] = Set.empty

/**
 * Builds the full set of [[Expression]]s that describe every constraint which can be enforced
 * on the current operator.
 *
 * When constraint propagation is switched off in the configuration, no constraints are
 * produced at all. Otherwise the result combines the operator's `validConstraints` with the
 * constraints derived from them by `inferAdditionalConstraints` and
 * `constructIsNotNullConstraints`.
 */
protected def constructAllConstraints: Set[Expression] = {
  if (!conf.constraintPropagationEnabled) {
    Set.empty
  } else {
    // Start from the operator's own constraints and add everything that can be derived from them.
    validConstraints ++
      inferAdditionalConstraints(validConstraints) ++
      constructIsNotNullConstraints(validConstraints)
  }
}

/**
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
* non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
Expand All @@ -76,7 +82,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
* of constraints.
*/
private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
protected def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
constraint match {
// When the root is IsNotNull, we can push IsNotNull through the child null intolerant
// expressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,26 @@ case class Join(
case UsingJoin(_, _) => false
case _ => resolvedExceptNatural
}

/**
 * Extends the inherited constraints with `isNotNull` constraints inferred from the join
 * condition that are not enforced on the result of the join, but can be enforced on one of
 * its inputs:
 *  - for left anti and left outer joins, on the right side;
 *  - for right outer joins, on the left side.
 *
 * For every other join type, or when there is no join condition, only the inherited
 * constraints are returned.
 */
override protected def constructAllConstraints: Set[Expression] = {
  // The input on which the extra constraints may be enforced, if the join type has one.
  val filterableSide = joinType match {
    case LeftAnti | LeftOuter => Some(right)
    case RightOuter => Some(left)
    case _ => None
  }
  // additional constraints which are not enforced on the result of join operations, but can be
  // enforced either on the left or the right side
  val additionalConstraints = for {
    side <- filterableSide.toSeq
    joinCondition <- condition.toSeq
    conjunct <- splitConjunctivePredicates(joinCondition)
    inferred <- inferIsNotNullConstraints(conjunct)
    // Keep only the constraints that refer exclusively to the chosen side's output.
    if inferred.references.subsetOf(side.outputSet)
  } yield inferred
  super.constructAllConstraints ++ additionalConstraints
}

/**
 * Constraints that hold on the output of this join.
 *
 * This is computed from `super.constructAllConstraints` rather than from this operator's own
 * override, so the additional constraints that are only enforceable on one side of the join
 * are not included here.
 */
override lazy val constraints: ExpressionSet = {
  val outputConstraints = super.constructAllConstraints.filter { constraint =>
    val refs = constraint.references
    // Keep deterministic constraints that mention at least one attribute, all of them in the
    // output of this operator.
    refs.nonEmpty && refs.subsetOf(outputSet) && constraint.deterministic
  }
  ExpressionSet(outputConstraints)
}
Copy link
Member

@dongjoon-hyun dongjoon-hyun Mar 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add more test cases (or statements) for this code path?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I added some statements to the ConstraintPropagationSuite.

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ class InferFiltersFromConstraintsSuite extends PlanTest {

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

/**
 * Joins `x` and `y` on `x.a === y.a` with the given join type, runs the optimizer on the
 * analyzed plan and checks that the result matches the same join applied to `expectedLeft`
 * and `expectedRight`.
 */
private def testConstraintsAfterJoin(
    x: LogicalPlan,
    y: LogicalPlan,
    expectedLeft: LogicalPlan,
    expectedRight: LogicalPlan,
    joinType: JoinType) = {
  val joinCondition = Some("x.a".attr === "y.a".attr)
  val analyzedInput = x.join(y, joinType, joinCondition).analyze
  val expectedPlan = expectedLeft.join(expectedRight, joinType, joinCondition).analyze
  comparePlans(Optimize.execute(analyzedInput), expectedPlan)
}

test("filter: filter out constraints in condition") {
val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
val correctAnswer = testRelation
Expand Down Expand Up @@ -196,12 +209,24 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftSemi, condition).analyze
val left = x.where(IsNotNull('a))
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftSemi, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
}

test("SPARK-23564: left anti join should filter out null join keys on right side") {
  val leftSide = testRelation.subquery('x)
  val rightSide = testRelation.subquery('y)
  // Only the right side is expected to gain an IsNotNull filter on the join key.
  val filteredRight = rightSide.where(IsNotNull('a))
  testConstraintsAfterJoin(leftSide, rightSide, leftSide, filteredRight, LeftAnti)
}

test("SPARK-23564: left outer join should filter out null join keys on right side") {
  val leftSide = testRelation.subquery('x)
  val rightSide = testRelation.subquery('y)
  // Only the right side is expected to gain an IsNotNull filter on the join key.
  val filteredRight = rightSide.where(IsNotNull('a))
  testConstraintsAfterJoin(leftSide, rightSide, leftSide, filteredRight, LeftOuter)
}

test("SPARK-23564: right outer join should filter out null join keys on left side") {
  val leftSide = testRelation.subquery('x)
  val rightSide = testRelation.subquery('y)
  // Only the left side is expected to gain an IsNotNull filter on the join key.
  val filteredLeft = leftSide.where(IsNotNull('a))
  testConstraintsAfterJoin(leftSide, rightSide, filteredLeft, rightSide, RightOuter)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,46 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
test("propagating constraints in left-outer join") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
verifyConstraints(tr1
val plan = tr1
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))))
.analyze
val expectedConstraints = ExpressionSet(Seq(
tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
verifyConstraints(plan.constraints, expectedConstraints)
verifyConstraints(plan.allConstraints, expectedConstraints +
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get))
}

test("propagating constraints in right-outer join") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
verifyConstraints(tr1
val plan = tr1
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
.analyze
val expectedConstraints = ExpressionSet(Seq(
tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
verifyConstraints(plan.constraints, expectedConstraints)
verifyConstraints(plan.allConstraints, expectedConstraints +
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))
}

test("propagating constraints in left-anti join") {
  val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
  val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
  val joined = tr1
    .where('a.attr > 10)
    .join(tr2.where('d.attr < 100), LeftAnti, Some("tr1.a".attr === "tr2.a".attr))
    .analyze

  val leftKey = tr1.resolveQuoted("a", caseInsensitiveResolution).get
  val rightKey = tr2.resolveQuoted("a", caseInsensitiveResolution).get

  // Constraints on the join output only mention the left side.
  val outputConstraints = ExpressionSet(Seq(leftKey > 10, IsNotNull(leftKey)))
  verifyConstraints(joined.constraints, outputConstraints)
  // allConstraints additionally carries the IsNotNull constraint on the right-side join key.
  verifyConstraints(joined.allConstraints, outputConstraints + IsNotNull(rightKey))
}

test("propagating constraints in full-outer join") {
Expand Down