Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-41162][SQL] Fix anti- and semi-join for self-join with aggregations #39131

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
}

// LeftSemi/LeftAnti over Aggregate, only push down if join can be planned as broadcast join.
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _)
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
!agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
canPushThroughCondition(agg.children, joinCond, rightOp) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should rewrite the joinCondition assuming the join has already been pushed through Aggregate. That said, we need to do alias replacement for joinCondition first. cc @EnricoMi

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand. The canPushThroughCondition is called before the Join is being pushed through the Aggregate, it has been added to prevent this from happening in this situation. The other cases (e.g. Union) are calling into canPushThroughCondition equivalently.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm, canPushThroughCondition checks the right side references of the join condition, and check if the right side references have conflict expr ID with left side plan (below Project) output. It doesn't care about the left side references of the join condition.

canPlanAsBroadcastHashJoin(join, conf) =>
val aliasMap = getAliasMap(agg)
val canPushDownPredicate = (predicate: Expression) => {
Expand Down Expand Up @@ -110,11 +111,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
}

/**
* Check if we can safely push a join through a project or union by making sure that attributes
* referred in join condition do not contain the same attributes as the plan they are moved
* into. This can happen when both sides of join refers to the same source (self join). This
* function makes sure that the join condition refers to attributes that are not ambiguous (i.e
* present in both the legs of the join) or else the resultant plan will be invalid.
* Check if we can safely push a join through a project, aggregate, or union by making sure that
* attributes referred in join condition do not contain the same attributes as the plan they are
* moved into. This can happen when both sides of join refers to the same source (self join).
* This function makes sure that the join condition refers to attributes that are not ambiguous
* (i.e present in both the legs of the join) or else the resultant plan will be invalid.
*/
private def canPushThroughCondition(
plans: Seq[LogicalPlan],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

class LeftSemiPushdownSuite extends PlanTest {
class LeftSemiAntiJoinPushDownSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Expand All @@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest {
val testRelation1 = LocalRelation($"d".int)
val testRelation2 = LocalRelation($"e".int)

test("Project: LeftSemiAnti join pushdown") {
test("Project: LeftSemi join pushdown") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These change to test names are necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The term LeftSemiAnti is wrong and misleading for individual tests, correcting this while I am touching the file.

val originalQuery = testRelation
.select(star())
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d"))
Expand All @@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") {
val originalQuery = testRelation
.select(Rand(1), $"b", $"c")
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d"))
Expand All @@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

test("Project: LeftSemiAnti join non correlated scalar subq") {
test("Project: LeftSemi join pushdown - non-correlated scalar subq") {
val subq = ScalarSubquery(testRelation.groupBy($"b")(sum($"c").as("sum")).analyze)
val originalQuery = testRelation
.select(subq.as("sum"))
Expand All @@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") {
val testRelation2 = LocalRelation($"e".int, $"f".int)
val subqPlan = testRelation2.groupBy($"e")(sum($"f").as("sum")).where($"e" === $"a")
val subqExpr = ScalarSubquery(subqPlan)
Expand All @@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

test("Aggregate: LeftSemiAnti join pushdown") {
test("Aggregate: LeftSemi join pushdown") {
val originalQuery = testRelation
.groupBy($"b")($"b", sum($"c"))
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d"))
Expand All @@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") {
val originalQuery = testRelation
.groupBy($"b")($"b", Rand(10).as("c"))
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d"))
Expand Down Expand Up @@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

test("LeftSemiAnti join over aggregate - no pushdown") {
test("Aggregate: LeftSemi join no pushdown") {
val originalQuery = testRelation
.groupBy($"b")($"b", sum($"c").as("sum"))
.join(testRelation1, joinType = LeftSemi, condition = Some($"b" === $"d" && $"sum" === $"d"))
Expand All @@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") {
val subq = ScalarSubquery(testRelation.groupBy($"b")(sum($"c").as("sum")).analyze)
val originalQuery = testRelation
.groupBy($"a") ($"a", subq.as("sum"))
Expand All @@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("LeftSemiAnti join over Window") {
test("Window: LeftSemi join pushdown") {
val winExpr = windowExpr(count($"b"),
windowSpec($"a" :: Nil, $"b".asc :: Nil, UnspecifiedFrame))

Expand All @@ -185,7 +185,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Window: LeftSemi partial pushdown") {
test("Window: LeftSemi join partial pushdown") {
// Attributes from join condition which does not refer to the window partition spec
// are kept up in the plan as a Filter operator above Window.
val winExpr = windowExpr(count($"b"),
Expand Down Expand Up @@ -227,7 +227,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Union: LeftSemiAnti join pushdown") {
test("Union: LeftSemi join pushdown") {
val testRelation2 = LocalRelation($"x".int, $"y".int, $"z".int)

val originalQuery = Union(Seq(testRelation, testRelation2))
Expand All @@ -243,7 +243,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Union: LeftSemiAnti join pushdown in self join scenario") {
test("Union: LeftSemi join pushdown in self join scenario") {
val testRelation2 = LocalRelation($"x".int, $"y".int, $"z".int)
val attrX = testRelation2.output.head

Expand All @@ -262,7 +262,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Unary: LeftSemiAnti join pushdown") {
test("Unary: LeftSemi join pushdown") {
val originalQuery = testRelation
.select(star())
.repartition(1)
Expand All @@ -277,7 +277,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Unary: LeftSemiAnti join pushdown - empty join condition") {
test("Unary: LeftSemi join pushdown - empty join condition") {
val originalQuery = testRelation
.select(star())
.repartition(1)
Expand All @@ -292,7 +292,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Unary: LeftSemi join pushdown - partial pushdown") {
test("Unary: LeftSemi join partial pushdown") {
val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType))
val originalQuery = testRelationWithArrayType
.generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col"))
Expand All @@ -309,7 +309,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Unary: LeftAnti join pushdown - no pushdown") {
test("Unary: LeftAnti join no pushdown") {
val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType))
val originalQuery = testRelationWithArrayType
.generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col"))
Expand All @@ -320,7 +320,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

test("Unary: LeftSemiAnti join pushdown - no pushdown") {
test("Unary: LeftSemi join - no pushdown") {
val testRelationWithArrayType = LocalRelation($"a".int, $"b".int, $"c_arr".array(IntegerType))
val originalQuery = testRelationWithArrayType
.generate(Explode($"c_arr"), alias = Some("arr"), outputNames = Seq("out_col"))
Expand All @@ -331,7 +331,7 @@ class LeftSemiPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}

test("Unary: LeftSemi join push down through Expand") {
test("Unary: LeftSemi join pushdown through Expand") {
val expand = Expand(Seq(Seq($"a", $"b", "null"), Seq($"a", "null", $"c")),
Seq($"a", $"b", $"c"), testRelation)
val originalQuery = expand
Expand Down Expand Up @@ -437,6 +437,25 @@ class LeftSemiPushdownSuite extends PlanTest {
}
}

Seq(LeftSemi, LeftAnti).foreach { case jt =>
test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
val aggregation = testRelation
.select($"b".as("id"), $"c")
.groupBy($"id")($"id", sum($"c").as("sum"))

// reference "b" exists in left leg, and the children of the right leg of the join
val originalQuery = aggregation.select(($"id" + 1).as("id_plus_1"), $"sum")
.join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select($"b".as("id"), $"c")
.groupBy($"id")(($"id" + 1).as("id_plus_1"), sum($"c").as("sum"))
.join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1"))
.analyze
comparePlans(optimized, correctAnswer)
}
}

Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,24 @@ class DataFrameJoinSuite extends QueryTest
}
}

Seq("left_semi", "left_anti").foreach { joinType =>
test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
// aggregated dataframe
val ids = Seq(1, 2, 3).toDF("id").distinct()

// self-joined via joinType
val result = ids.withColumn("id", $"id" + 1)
.join(ids, "id", joinType).collect()

val expected = joinType match {
case "left_semi" => 2
case "left_anti" => 1
case _ => -1 // unsupported test type, test will always fail
}
assert(result.length == expected)
}
}

def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
case Filter(_, child) => extractLeftDeepInnerJoins(child)
Expand Down