-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-41162][SQL] Fix anti- and semi-join for self-join with aggregations #39131
Changes from all commits
eebeb34
76dcd1e
ade61b8
cc28b21
2ae3a41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._ | |
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.types.IntegerType | ||
|
||
class LeftSemiPushdownSuite extends PlanTest { | ||
class LeftSemiAntiJoinPushDownSuite extends PlanTest { | ||
|
||
object Optimize extends RuleExecutor[LogicalPlan] { | ||
val batches = | ||
|
@@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
val testRelation1 = LocalRelation($"d".int) | ||
val testRelation2 = LocalRelation($"e".int) | ||
|
||
test("Project: LeftSemiAnti join pushdown") { | ||
test("Project: LeftSemi join pushdown") { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are these changes to the test names necessary? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The term |
||
val originalQuery = testRelation | ||
.select(star()) | ||
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) | ||
|
@@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") { | ||
test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") { | ||
val originalQuery = testRelation | ||
.select(Rand(1), $"b", $"c") | ||
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) | ||
|
@@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, originalQuery.analyze) | ||
} | ||
|
||
test("Project: LeftSemiAnti join non correlated scalar subq") { | ||
test("Project: LeftSemi join pushdown - non-correlated scalar subq") { | ||
val subq = ScalarSubquery(testRelation.groupBy($"b")(sum($"c").as("sum")).analyze) | ||
val originalQuery = testRelation | ||
.select(subq.as("sum")) | ||
|
@@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") { | ||
test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") { | ||
val testRelation2 = LocalRelation($"e".int, $"f".int) | ||
val subqPlan = testRelation2.groupBy($"e")(sum($"f").as("sum")).where($"e" === $"a") | ||
val subqExpr = ScalarSubquery(subqPlan) | ||
|
@@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, originalQuery.analyze) | ||
} | ||
|
||
test("Aggregate: LeftSemiAnti join pushdown") { | ||
test("Aggregate: LeftSemi join pushdown") { | ||
val originalQuery = testRelation | ||
.groupBy($"b")($"b", sum($"c")) | ||
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) | ||
|
@@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") { | ||
test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") { | ||
val originalQuery = testRelation | ||
.groupBy($"b")($"b", Rand(10).as("c")) | ||
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d")) | ||
|
@@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, originalQuery.analyze) | ||
} | ||
|
||
test("LeftSemiAnti join over aggregate - no pushdown") { | ||
test("Aggregate: LeftSemi join no pushdown") { | ||
val originalQuery = testRelation | ||
.groupBy($"b")($"b", sum($"c").as("sum")) | ||
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d" && $"sum" === $"d")) | ||
|
@@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, originalQuery.analyze) | ||
} | ||
|
||
test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") { | ||
test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") { | ||
val subq = ScalarSubquery(testRelation.groupBy($"b")(sum($"c").as("sum")).analyze) | ||
val originalQuery = testRelation | ||
.groupBy($"a") ($"a", subq.as("sum")) | ||
|
@@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("LeftSemiAnti join over Window") { | ||
test("Window: LeftSemi join pushdown") { | ||
val winExpr = windowExpr(count($"b"), | ||
windowSpec($"a" :: Nil, $"b".asc :: Nil, UnspecifiedFrame)) | ||
|
||
|
@@ -185,7 +185,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Window: LeftSemi partial pushdown") { | ||
test("Window: LeftSemi join partial pushdown") { | ||
// Attributes from join condition which does not refer to the window partition spec | ||
// are kept up in the plan as a Filter operator above Window. | ||
val winExpr = windowExpr(count($"b"), | ||
|
@@ -227,7 +227,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Union: LeftSemiAnti join pushdown") { | ||
test("Union: LeftSemi join pushdown") { | ||
val testRelation2 = LocalRelation($"x".int, $"y".int, $"z".int) | ||
|
||
val originalQuery = Union(Seq(testRelation, testRelation2)) | ||
|
@@ -243,7 +243,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Union: LeftSemiAnti join pushdown in self join scenario") { | ||
test("Union: LeftSemi join pushdown in self join scenario") { | ||
val testRelation2 = LocalRelation($"x".int, $"y".int, $"z".int) | ||
val attrX = testRelation2.output.head | ||
|
||
|
@@ -262,7 +262,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Unary: LeftSemiAnti join pushdown") { | ||
test("Unary: LeftSemi join pushdown") { | ||
val originalQuery = testRelation | ||
.select(star()) | ||
.repartition(1) | ||
|
@@ -277,7 +277,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Unary: LeftSemiAnti join pushdown - empty join condition") { | ||
test("Unary: LeftSemi join pushdown - empty join condition") { | ||
val originalQuery = testRelation | ||
.select(star()) | ||
.repartition(1) | ||
|
@@ -292,7 +292,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Unary: LeftSemi join pushdown - partial pushdown") { | ||
test("Unary: LeftSemi join partial pushdown") { | ||
val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) | ||
val originalQuery = testRelationWithArrayType | ||
.generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) | ||
|
@@ -309,7 +309,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, correctAnswer) | ||
} | ||
|
||
test("Unary: LeftAnti join pushdown - no pushdown") { | ||
test("Unary: LeftAnti join no pushdown") { | ||
val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) | ||
val originalQuery = testRelationWithArrayType | ||
.generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) | ||
|
@@ -320,7 +320,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, originalQuery.analyze) | ||
} | ||
|
||
test("Unary: LeftSemiAnti join pushdown - no pushdown") { | ||
test("Unary: LeftSemi join - no pushdown") { | ||
val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType)) | ||
val originalQuery = testRelationWithArrayType | ||
.generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col")) | ||
|
@@ -331,7 +331,7 @@ class LeftSemiPushdownSuite extends PlanTest { | |
comparePlans(optimized, originalQuery.analyze) | ||
} | ||
|
||
test("Unary: LeftSemi join push down through Expand") { | ||
test("Unary: LeftSemi join pushdown through Expand") { | ||
val expand = Expand(Seq(Seq($"a", $"b", "null"), Seq($"a", "null", $"c")), | ||
Seq($"a", $"b", $"c"), testRelation) | ||
val originalQuery = expand | ||
|
@@ -437,6 +437,25 @@ class LeftSemiPushdownSuite extends PlanTest { | |
} | ||
} | ||
|
||
// Registers one test per join type (LeftSemi and LeftAnti) covering a self-join in which both | ||
// legs of the join are derived from the same aggregation plan. | ||
Seq(LeftSemi, LeftAnti).foreach { case jt => | ||
test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") { | ||
// Shared plan used for both join legs: alias "b" to "id", then sum "c" per "id". | ||
val aggregation = testRelation | ||
.select($"b".as("id"), $"c") | ||
.groupBy($"id")($"id", sum($"c").as("sum")) | ||
|
||
// The attribute "b" (aliased to "id") is referenced by the left leg and also by the | ||
// children of the right leg of the join, because both legs are built from `aggregation`. | ||
val originalQuery = aggregation.select(($"id" + 1).as("id_plus_1"), $"sum") | ||
.join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1")) | ||
val optimized = Optimize.execute(originalQuery.analyze) | ||
// Expected plan: the "id_plus_1" projection appears in the Aggregate's output list, while | ||
// the join itself stays on top of the Aggregate (i.e. it is not pushed below it). | ||
val correctAnswer = testRelation | ||
.select($"b".as("id"), $"c") | ||
.groupBy($"id")(($"id" + 1).as("id_plus_1"), sum($"c").as("sum")) | ||
.join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1")) | ||
.analyze | ||
comparePlans(optimized, correctAnswer) | ||
} | ||
} | ||
|
||
Seq(LeftSemi, LeftAnti).foreach { case outerJT => | ||
Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT => | ||
test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should rewrite the `joinCondition` assuming the join has already been pushed through the Aggregate. That said, we need to do alias replacement for `joinCondition` first. cc @EnricoMi
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand. `canPushThroughCondition` is called before the `Join` is pushed through the `Aggregate`; it has been added to prevent the pushdown from happening in this situation. The other cases (e.g. `Union`) call into `canPushThroughCondition` equivalently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nvm, `canPushThroughCondition` checks the right-side references of the join condition, and checks whether those right-side references have conflicting expression IDs with the output of the left-side plan (below the Project). It doesn't care about the left-side references of the join condition.