Skip to content

Commit

Permalink
added more tests, refactored existing ones, made back private method
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Mar 23, 2018
1 parent 9e2d993 commit 5cadd86
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ class InferFiltersFromConstraintsSuite extends PlanTest {

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

private def testConstraintsAfterJoin(
x: LogicalPlan,
y: LogicalPlan,
expectedLeft: LogicalPlan,
expectedRight: LogicalPlan,
joinType: JoinType) = {
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, joinType, condition).analyze
val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("filter: filter out constraints in condition") {
val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
val correctAnswer = testRelation
Expand Down Expand Up @@ -196,48 +209,24 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftSemi, condition).analyze
val left = x.where(IsNotNull('a))
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftSemi, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
}

test("SPARK-23564: left anti join should filter out null join keys on right side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftAnti, condition).analyze
val left = x
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftAnti, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti)
}

test("SPARK-23564: left outer join should filter out null join keys on right side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftOuter, condition).analyze
val left = x
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftOuter, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter)
}

test("SPARK-23564: right outer join should filter out null join keys on left side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, RightOuter, condition).analyze
val left = x.where(IsNotNull('a))
val right = y
val correctAnswer = left.join(right, RightOuter, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,46 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
test("propagating constraints in left-outer join") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
verifyConstraints(tr1
val plan = tr1
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))))
.analyze
val expectedConstraints = ExpressionSet(Seq(
tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
verifyConstraints(plan.constraints, expectedConstraints)
verifyConstraints(plan.allConstraints, expectedConstraints +
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get))
}

test("propagating constraints in right-outer join") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
verifyConstraints(tr1
val plan = tr1
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
.analyze
val expectedConstraints = ExpressionSet(Seq(
tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
verifyConstraints(plan.constraints, expectedConstraints)
verifyConstraints(plan.allConstraints, expectedConstraints +
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))
}

test("propagating constraints in left-anti join") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
val plan = tr1
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), LeftAnti, Some("tr1.a".attr === "tr2.a".attr))
.analyze
val expectedConstraints = ExpressionSet(Seq(
tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
verifyConstraints(plan.constraints, expectedConstraints)
verifyConstraints(plan.allConstraints, expectedConstraints +
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get))
}

test("propagating constraints in full-outer join") {
Expand Down

0 comments on commit 5cadd86

Please sign in to comment.