From 5cadd86ec4fae40c8d2606f0c00aed99a96d0027 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 23 Mar 2018 15:23:33 +0100 Subject: [PATCH] added more tests, refactored existing ones, made back private method --- .../plans/logical/QueryPlanConstraints.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 45 +++++++------------ .../plans/ConstraintPropagationSuite.scala | 39 ++++++++++++---- 3 files changed, 49 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 857b4d4060618..219231dc792cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -97,7 +97,7 @@ trait QueryPlanConstraints { self: LogicalPlan => * Recursively explores the expressions which are null intolerant and returns all attributes * in these expressions. */ - protected def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { + private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { case a: Attribute => Seq(a) case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 4ad6ad02108c7..04d5af603ef1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -40,6 +40,19 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("filter: filter out constraints in condition") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -196,48 +209,24 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftSemi, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftSemi, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) } test("SPARK-23564: left anti join should filter out null join keys on right side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftAnti, condition).analyze - val left = x - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftAnti, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) } test("SPARK-23564: left outer join should filter out null join keys on right side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftOuter, condition).analyze - val left = x - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) } test("SPARK-23564: right outer join should filter out null join keys on left side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, RightOuter, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y - val correctAnswer = left.join(right, RightOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a37e06d922642..b19f5a7fde4ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -237,23 +237,46 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { test("propagating constraints in left-outer join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1 + val plan = tr1 .where('a.attr > 10) .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) - .analyze.constraints, - ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, - IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))) + .analyze + val expectedConstraints = ExpressionSet(Seq( + tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + verifyConstraints(plan.constraints, expectedConstraints) + verifyConstraints(plan.allConstraints, expectedConstraints + + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get)) } test("propagating constraints in right-outer join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1 + val plan = tr1 .where('a.attr > 10) .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) - .analyze.constraints, - ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, - IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) + .analyze + val expectedConstraints = ExpressionSet(Seq( + tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + verifyConstraints(plan.constraints, expectedConstraints) + verifyConstraints(plan.allConstraints, expectedConstraints + + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)) + } + + test("propagating constraints in left-anti join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val plan = tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftAnti, Some("tr1.a".attr === "tr2.a".attr)) + .analyze + val expectedConstraints = ExpressionSet(Seq( + tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + verifyConstraints(plan.constraints, expectedConstraints) + verifyConstraints(plan.allConstraints, expectedConstraints + + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get)) } test("propagating constraints in full-outer join") {