From 0feb4f210c47137dfdd2b9c01d67d047663ad539 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 30 Apr 2026 16:02:24 +0200 Subject: [PATCH 1/2] [SPARK-56677][SQL] Propagate filter conditions through Join nodes in PlanMerger ### What changes were proposed in this pull request? `PlanMerger` now supports filter propagation through `Join` nodes when merging similar subplans. Previously, when two subplans contained identical `Join` nodes but differed only in a filter applied to one of the join's children, they could not be merged. This PR adds the ability to propagate such filter conditions through a `Join` and into the parent `Aggregate`'s `FILTER` clause. A new `filterSafeForJoin` helper checks that the filter originates from the non-nullable (preserved) side of the join: the left side of `LeftOuter`/`LeftSemi`/`LeftAnti`, the right side of `RightOuter`, or either side of `Inner`/`Cross`. `FullOuter` joins are not eligible. The feature is gated by a new SQL config: `spark.sql.optimizer.mergeSubplans.filterPropagation.throughJoin.enabled` (default: `false`). ### Why are the changes needed? Without this change, scalar subqueries that differ only in a filter on one side of an identical join cannot be merged, resulting in redundant scans and compute. For example: SELECT (SELECT sum(key) FROM t1 JOIN t2 ON t1.id = t2.id), (SELECT sum(key) FROM t1 JOIN t2 ON t1.id = t2.id WHERE t2.b > 1) Both subqueries scan `t1` and `t2` in full even though they share the same base join. After this change a single merged scan is used and the second subquery's result is derived from it via an aggregate `FILTER` clause. ### Does this PR introduce _any_ user-facing change? Yes. The optimizer may now merge scalar subqueries that were previously kept separate, reducing the number of scan and join operations. The new config `spark.sql.optimizer.mergeSubplans.filterPropagation.throughJoin.enabled` (default `false`) can be used to opt in. ### How was this patch tested? 
Added unit tests in `MergeSubplansSuite`: - Merge with filter on left inner join child - Merge with filter on right inner join child - No merge when both join children have independent filters - Merge with filter on the preserved side of a `LeftSemi` join - No merge when filter is on the non-output side of a `LeftSemi` join - No merge when filter is on the nullable side of an outer join - No merge when the feature is disabled via config Added integration test in `PlanMergeSuite` verifying correctness (`checkAnswer`) and plan shape (`SubqueryExec`/`ReusedSubqueryExec` counts) for both the enabled and disabled config cases, with and without AQE. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Claude Sonnet 4.6 --- .../sql/catalyst/optimizer/PlanMerger.scala | 71 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 15 ++ .../optimizer/MergeSubplansSuite.scala | 176 ++++++++++++++++++ .../org/apache/spark/sql/PlanMergeSuite.scala | 43 +++++ 4 files changed, 296 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala index a85bed783de6e..91ad88d3b372f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, If, Literal, NamedExpression, Or} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import 
org.apache.spark.sql.internal.SQLConf @@ -91,6 +92,11 @@ object PlanMerger { * When plans also differ in intermediate [[Project]] expressions, those are wrapped with * `If(filterAttr, expr, null)` to avoid computing the expression for rows that do not * match that side's filter condition. + * Filter propagation also works through [[Join]] nodes: a filter on one child of the join + * produces a boolean attribute that flows through the join output to the enclosing + * [[Aggregate]]. Propagation is skipped when both the left and right children simultaneously + * produce filter attributes, as combining them would require an additional AND alias above + * the join (not yet supported). * * {{{ * // Input plans @@ -120,7 +126,9 @@ class PlanMerger( filterPropagationEnabled: Boolean = SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED), symmetricFilterPropagationEnabled: Boolean = - SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED)) { + SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED), + filterPropagationThroughJoinEnabled: Boolean = + SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED)) { val cache = mutable.ArrayBuffer.empty[MergedPlan] /** @@ -224,7 +232,8 @@ class PlanMerger( * - Aggregate nodes: Combines aggregate expressions if grouping is identical and both * support the same aggregate implementation (hash/object-hash/sort-based) * - Filter nodes: Only if filter conditions are identical - * - Join nodes: Only if join type, hints, and conditions are identical + * - Join nodes: Requires identical join type, hints, and conditions; filter propagation is + * forwarded into the join's children so a filter difference on one child can still be merged * * @param newPlan The plan to merge into the cached plan. * @param cachedPlan The cached plan to merge with. 
@@ -416,18 +425,37 @@ class PlanMerger( } case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint => - // Filter propagation across joins is not yet supported. - tryMergePlans(np.left, cp.left, false).flatMap { - case TryMergeResult(mergedLeft, leftNPMapping, None, None) => - tryMergePlans(np.right, cp.right, false).flatMap { - case TryMergeResult(mergedRight, rightNPMapping, None, None) => + tryMergePlans(np.left, cp.left, filterPropagationSupported).flatMap { + case TryMergeResult(mergedLeft, leftNPMapping, leftNPFilter, leftCPFilter) => + tryMergePlans(np.right, cp.right, filterPropagationSupported).flatMap { + case TryMergeResult(mergedRight, rightNPMapping, rightNPFilter, rightCPFilter) + // If both children independently propagate filter attributes we would need to + // AND them into a new alias above the join, which is not yet supported. + if !(leftNPFilter.isDefined && rightNPFilter.isDefined) && + !(leftCPFilter.isDefined && rightCPFilter.isDefined) && + // Gate join-crossing filter propagation behind its own config flag. + // When no filter attributes are in play the merge is unconditionally safe. + (leftNPFilter.isEmpty && leftCPFilter.isEmpty && + rightNPFilter.isEmpty && rightCPFilter.isEmpty || + filterPropagationThroughJoinEnabled) && + // A filter attribute is only safe to propagate through a join if it comes + // from the "preserved" (non-nullable) side. On the nullable side, unmatched + // rows are NULL-padded so f=NULL, causing FILTER (WHERE f) to incorrectly + // exclude rows that should contribute to the aggregate. Right-side + // attributes are also absent from semi/anti join output. 
+ (leftNPFilter.isEmpty && leftCPFilter.isEmpty || + filterSafeForJoin(fromLeft = true, cp.joinType)) && + (rightNPFilter.isEmpty && rightCPFilter.isEmpty || + filterSafeForJoin(fromLeft = false, cp.joinType)) => val npMapping = leftNPMapping ++ rightNPMapping val mappedNPCondition = np.condition.map(mapAttributes(_, npMapping)) // Comparing the canonicalized form is required to ignore different forms of the // same expression and `AttributeReference.qualifier`s in `cp.condition`. if (mappedNPCondition.map(_.canonicalized) == cp.condition.map(_.canonicalized)) { - val mergedPlan = cp.withNewChildren(Seq(mergedLeft, mergedRight)) - Some(TryMergeResult(mergedPlan, npMapping)) + val npFilter = leftNPFilter.orElse(rightNPFilter) + val cpFilter = leftCPFilter.orElse(rightCPFilter) + Some(TryMergeResult(cp.withNewChildren(Seq(mergedLeft, mergedRight)), npMapping, + npFilter, cpFilter)) } else { None } @@ -441,6 +469,31 @@ class PlanMerger( }) } + // Returns true when a filter attribute originating from `fromLeft` child of a join with + // `joinType` can be safely propagated through that join to a parent Aggregate. + // + // Two conditions must both hold: + // 1. The attribute is in the join's output (rules out the right side of LeftSemi/LeftAnti). + // 2. The attribute is never NULL in the join's output, i.e. it comes from the "preserved" + // side that is never NULL-padded. For an outer join, unmatched rows from the nullable + // side are padded with NULLs, so a filter attribute from that side would be NULL for + // those rows. FILTER (WHERE NULL) would then incorrectly exclude those rows from the + // aggregate, even though they appear in the original join result and should contribute. + private def filterSafeForJoin(fromLeft: Boolean, joinType: JoinType): Boolean = + if (fromLeft) { + // Left side is never NULL-padded in: Inner, LeftOuter, LeftSemi, LeftAnti, Cross. 
+ joinType match { + case Inner | LeftOuter | LeftSemi | LeftAnti | Cross => true + case _ => false // RightOuter and FullOuter can NULL-pad the left side + } + } else { + // Right side is never NULL-padded AND is in the join output in: Inner, RightOuter, Cross. + joinType match { + case Inner | RightOuter | Cross => true + case _ => false // LeftOuter/FullOuter can NULL-pad right; LeftSemi/LeftAnti drop right + } + } + private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = { expr.transform { case a: Attribute => outputMap.getOrElse(a, a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7fbf2797dc93d..af362c97d58ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -6608,6 +6608,21 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED = + buildConf( + "spark.sql.optimizer.mergeSubplans.filterPropagation.filterPropagationThroughJoin.enabled") + .doc("When set to true, filter attributes can propagate through Join nodes during subplan " + + "merging, allowing subplans that differ only in their filter conditions and share a " + + "common join to be merged into a single scan. A filter attribute is only propagated " + + "through a join when it originates from the non-nullable (preserved) side: the left side " + + "of LeftOuter/LeftSemi/LeftAnti, the right side of RightOuter, or either side of " + + "Inner/Cross. FullOuter joins are never eligible. 
" + + s"Has no effect when ${MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED.key} is false.") + .version("4.2.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val ERROR_MESSAGE_FORMAT = buildConf("spark.sql.error.messageFormat") .doc("When PRETTY, the error message consists of textual representation of error class, " + "message and query context. Stack traces are only shown for internal errors " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala index e685c756a4b73..c2b69ecf1e0d5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala @@ -37,6 +37,7 @@ class MergeSubplansSuite extends PlanTest { } val testRelation = LocalRelation($"a".int, $"b".int, $"c".string) + val testRelation2 = LocalRelation($"d".int, $"e".int) val testRelationWithNonBinaryCollation = LocalRelation( $"utf8_binary".string("UTF8_BINARY"), $"utf8_lcase".string("UTF8_LCASE")) @@ -1542,4 +1543,179 @@ class MergeSubplansSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) } + + test("SPARK-56677: Merge non-grouping subqueries with filter on left join child") { + // cp (subquery1): Aggregate([], [sum(a)], Join(testRelation, testRelation2, a=d)) + // np (subquery2): Aggregate([], [max(a)], Join(Filter(a>1, testRelation), testRelation2, a=d)) + // The filter on the left join child propagates as a boolean attribute through the Join node + // and is consumed as a FILTER (WHERE ...) clause on the np-side aggregate expression. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()( + sum($"a").as("sum_a"), + max($"a", Some(f0)).as("max_a")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-56677: Merge non-grouping subqueries with filter on right join child") { + // cp (subquery1): Aggregate([], [sum(a)], Join(testRelation, testRelation2, a=d)) + // np (subquery2): Aggregate([], [max(d)], Join(testRelation, Filter(d>1, testRelation2), a=d)) + // The filter on the right join child propagates analogously to the left-child case. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.join(testRelation2.where($"d" > 1), Inner, Some($"a" === $"d")) + .groupBy()(max($"d").as("max_d"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"d" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .join( + testRelation2.select(testRelation2.output ++ Seq(f0Alias): _*), + Inner, Some($"a" === $"d")) + .groupBy()( + sum($"a").as("sum_a"), + max($"d", Some(f0)).as("max_d")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_d"), $"max_d" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-56677: Do not merge subqueries when both join children have independent filters") { + // np has filters on BOTH left and right join children simultaneously. The guard in the + // Join case prevents this merge because combining two independent filter attributes would + // require ANDing them into a new alias, which is not yet supported. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2.where($"d" > 1), Inner, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-56677: Merge non-grouping subqueries with filter on left side of LeftSemi join") { + // Left-side filter attributes ARE in the LeftSemi join output, so propagation is safe. + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, LeftSemi, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, LeftSemi, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .join(testRelation2, LeftSemi, Some($"a" === $"d")) + .groupBy()( + sum($"a").as("sum_a"), + max($"a", Some(f0)).as("max_a")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-56677: Do 
not merge subqueries when filter is on the right side of a LeftSemi join") { + // Right-side filter attributes are NOT in the LeftSemi join output (only left-side columns + // are produced). Propagating such a filter would create an unresolvable attribute reference + // in the parent Aggregate's FILTER clause. + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, LeftSemi, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.join(testRelation2.where($"d" > 1), LeftSemi, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-56677: Do not merge subqueries when filter is on the nullable side of an outer " + + "join") { + // For a RightOuter join the left side is nullable: unmatched right rows produce NULL for all + // left-side columns including the filter attribute f, so FILTER (WHERE f=NULL) would + // incorrectly exclude those rows from the aggregate even though they appear in the join result. + // The same problem applies to the right side of a LeftOuter join and both sides of FullOuter. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, RightOuter, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, RightOuter, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-56677: Do not merge subqueries with filter propagation through join when disabled") { + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "false") { + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala index 1e31453b42f23..e1109f20e6040 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala @@ -395,4 +395,47 @@ class PlanMergeSuite extends SharedSparkSession } } } + + test("SPARK-56677: Merge scalar subqueries with filter propagation through Join") { + // subquery1 has no filter; subquery2 filters on b > 1 (a column from the right side of the join + // that is not part of the join condition). Predicate pushdown can only push this filter to + // testData2, not to testData, so only the right child differs between the two subqueries. 
+ Seq(false, true).foreach { enableAQE => + Seq(true, false).foreach { filterPropagationThroughJoinEnabled => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> + filterPropagationThroughJoinEnabled.toString, + // ObjectSerializerPruning produces different scan shapes depending on whether a Filter is + // present. Disabling the rule makes both scans identical so PlanMerger can merge them. + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning") { + val df = sql( + """ + |SELECT + | (SELECT sum(key) FROM testData JOIN testData2 ON key = a), + | (SELECT sum(key) FROM testData JOIN testData2 ON key = a WHERE b > 1) + """.stripMargin) + + checkAnswer(df, Row(12, 6) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + if (filterPropagationThroughJoinEnabled) { + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 1, + "Missing or unexpected ReusedSubqueryExec in the plan") + } else { + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 0, + "Missing or unexpected ReusedSubqueryExec in the plan") + } + } + } + } + } } From 4d87515c6de9583e4f6999094cb91a6d4dd70312 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 1 May 2026 15:24:46 +0200 Subject: [PATCH 2/2] address review findings --- .../sql/catalyst/optimizer/PlanMerger.scala | 41 +++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../optimizer/MergeSubplansSuite.scala | 51 +++++++++++++++++++ 3 files changed, 79 insertions(+), 18 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala index 91ad88d3b372f..a85fad2f758c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -86,17 +86,24 @@ object PlanMerger { * When `filterPropagationEnabled` is true, non-grouping [[Aggregate]]s over the same base plan * with different [[Filter]] conditions can also be merged. The filter conditions are exposed as * boolean [[Project]] attributes and consumed at the [[Aggregate]] as FILTER clauses. - * When both sides carry a [[Filter]] (the symmetric case), merging broadens the scan to - * OR(f1, f2), which may reduce IO pruning. This path is separately gated by + * When both sides carry a [[Filter]] (the symmetric case), merging broadens the scan to OR(f1, f2), + * which may reduce IO pruning. This path is separately gated by * `symmetricFilterPropagationEnabled`. * When plans also differ in intermediate [[Project]] expressions, those are wrapped with - * `If(filterAttr, expr, null)` to avoid computing the expression for rows that do not - * match that side's filter condition. - * Filter propagation also works through [[Join]] nodes: a filter on one child of the join - * produces a boolean attribute that flows through the join output to the enclosing - * [[Aggregate]]. Propagation is skipped when both the left and right children simultaneously - * produce filter attributes, as combining them would require an additional AND alias above - * the join (not yet supported). + * `If(filterAttr, expr, null)` to avoid computing the expression for rows that do not match that + * side's filter condition. 
+ * Filter propagation also works through [[Join]] nodes: a filter on one child of the join produces + * a boolean attribute that flows through the join output to the enclosing [[Aggregate]]. + * Propagation is only safe when the filter originates from the non-nullable side of the join, as + * enforced by `filterSafeForJoin`. When the filter is on the nullable side, the merged base plan + * restores rows that were filtered out of the nullable child, turning what were unmatched + * NULL-padded rows in the original plan into matched rows with real column values. This changes the + * result of expressions like `coalesce(col, default)` in the aggregate: an originally unmatched row + * would have contributed `default` via `coalesce(NULL, default)`, but in the merged plan it is + * matched, its real column value fails the filter, and `FILTER (WHERE false)` discards it entirely. + * Propagation is also skipped when both the left and right children simultaneously produce filter + * attributes, as combining them would require an additional AND alias above the join (not yet + * supported). * * {{{ * // Input plans @@ -443,7 +450,7 @@ class PlanMerger( // rows are NULL-padded so f=NULL, causing FILTER (WHERE f) to incorrectly // exclude rows that should contribute to the aggregate. Right-side // attributes are also absent from semi/anti join output. - (leftNPFilter.isEmpty && leftCPFilter.isEmpty || + (leftNPFilter.isEmpty && leftCPFilter.isEmpty || filterSafeForJoin(fromLeft = true, cp.joinType)) && (rightNPFilter.isEmpty && rightCPFilter.isEmpty || filterSafeForJoin(fromLeft = false, cp.joinType)) => @@ -474,11 +481,15 @@ class PlanMerger( // // Two conditions must both hold: // 1. The attribute is in the join's output (rules out the right side of LeftSemi/LeftAnti). - // 2. The attribute is never NULL in the join's output, i.e. it comes from the "preserved" - // side that is never NULL-padded. 
For an outer join, unmatched rows from the nullable - // side are padded with NULLs, so a filter attribute from that side would be NULL for - // those rows. FILTER (WHERE NULL) would then incorrectly exclude those rows from the - // aggregate, even though they appear in the original join result and should contribute. + // 2. The filter must originate from the non-nullable ("preserved") side of the join. + // When a filter is on the nullable side, the merged base plan no longer applies it to the + // nullable child's scan, so rows that were previously absent from that child reappear as + // matched join rows instead of unmatched NULL-padded rows. This changes aggregate + // expressions that use the NULL-padded column: e.g. for `sum(coalesce(col, default))`, an + // originally unmatched row would have contributed `default` via `coalesce(NULL, default)`, + // but in the merged plan the row is now matched with its real column value, fails the + // filter, and FILTER (WHERE false) discards it -- losing the `default` contribution + // entirely. private def filterSafeForJoin(fromLeft: Boolean, joinType: JoinType): Boolean = if (fromLeft) { // Left side is never NULL-padded in: Inner, LeftOuter, LeftSemi, LeftAnti, Cross. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index af362c97d58ac..83f1816c9727f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -6609,13 +6609,12 @@ object SQLConf { .createWithDefault(false) val MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED = - buildConf( - "spark.sql.optimizer.mergeSubplans.filterPropagation.filterPropagationThroughJoin.enabled") + buildConf("spark.sql.optimizer.mergeSubplans.filterPropagation.throughJoin.enabled") .doc("When set to true, filter attributes can propagate through Join nodes during subplan " + "merging, allowing subplans that differ only in their filter conditions and share a " + "common join to be merged into a single scan. A filter attribute is only propagated " + "through a join when it originates from the non-nullable (preserved) side: the left side " + - "of LeftOuter/LeftSemi/LeftAnti, the right side of RightOuter, or either side of " + + "of LeftOuter/LeftSemi/LeftAnti, the right side of RightOuter, or either side of " + "Inner/Cross. FullOuter joins are never eligible. 
" + s"Has no effect when ${MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED.key} is false.") .version("4.2.0") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala index c2b69ecf1e0d5..c9e5db5810597 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala @@ -1618,6 +1618,40 @@ class MergeSubplansSuite extends PlanTest { } } + test("SPARK-56677: Merge non-grouping subqueries with filter on left child of a Cross join") { + // Cross join never NULL-pads either side, so filter propagation is safe from both sides. + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Cross, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, Cross, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .join(testRelation2, Cross, Some($"a" === $"d")) + .groupBy()( + sum($"a").as("sum_a"), + max($"a", Some(f0)).as("max_a")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + 
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + test("SPARK-56677: Do not merge subqueries when both join children have independent filters") { // np has filters on BOTH left and right join children simultaneously. The guard in the // Join case prevents this merge because combining two independent filter attributes would @@ -1705,6 +1739,23 @@ class MergeSubplansSuite extends PlanTest { } } + test("SPARK-56677: Do not merge subqueries when filter is on either side of a FullOuter join") { + // For a FullOuter join both sides are nullable: unmatched rows from either side produce NULL + // for the other side's columns. A filter attribute from either side would be NULL for those + // unmatched rows, making propagation unsafe from both sides. + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, FullOuter, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, FullOuter, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + test("SPARK-56677: Do not merge subqueries with filter propagation through join when disabled") { withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "false") { val subquery1 = ScalarSubquery(