diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala index a85bed783de6..a85fad2f758c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, If, Literal, NamedExpression, Or} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.internal.SQLConf @@ -85,12 +86,24 @@ object PlanMerger { * When `filterPropagationEnabled` is true, non-grouping [[Aggregate]]s over the same base plan * with different [[Filter]] conditions can also be merged. The filter conditions are exposed as * boolean [[Project]] attributes and consumed at the [[Aggregate]] as FILTER clauses. - * When both sides carry a [[Filter]] (the symmetric case), merging broadens the scan to - * OR(f1, f2), which may reduce IO pruning. This path is separately gated by + * When both sides carry a [[Filter]] (the symmetric case), merging broadens the scan to OR(f1, f2), + * which may reduce IO pruning. This path is separately gated by * `symmetricFilterPropagationEnabled`. * When plans also differ in intermediate [[Project]] expressions, those are wrapped with - * `If(filterAttr, expr, null)` to avoid computing the expression for rows that do not - * match that side's filter condition. 
+ * `If(filterAttr, expr, null)` to avoid computing the expression for rows that do not match that + * side's filter condition. + * Filter propagation also works through [[Join]] nodes: a filter on one child of the join produces + * a boolean attribute that flows through the join output to the enclosing [[Aggregate]]. + * Propagation is only safe when the filter originates from the non-nullable side of the join, as + * enforced by `filterSafeForJoin`. When the filter is on the nullable side, the merged base plan + * restores rows that were filtered out of the nullable child, turning what were unmatched + * NULL-padded rows in the original plan into matched rows with real column values. This changes the + * result of expressions like `coalesce(col, default)` in the aggregate: an originally unmatched row + * would have contributed `default` via `coalesce(NULL, default)`, but in the merged plan it is + * matched, its real column value fails the filter, and `FILTER (WHERE false)` discards it entirely. + * Propagation is also skipped when both the left and right children simultaneously produce filter + * attributes, as combining them would require an additional AND alias above the join (not yet + * supported). 
* * {{{ * // Input plans @@ -120,7 +133,9 @@ class PlanMerger( filterPropagationEnabled: Boolean = SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED), symmetricFilterPropagationEnabled: Boolean = - SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED)) { + SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED), + filterPropagationThroughJoinEnabled: Boolean = + SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED)) { val cache = mutable.ArrayBuffer.empty[MergedPlan] /** @@ -224,7 +239,8 @@ class PlanMerger( * - Aggregate nodes: Combines aggregate expressions if grouping is identical and both * support the same aggregate implementation (hash/object-hash/sort-based) * - Filter nodes: Only if filter conditions are identical - * - Join nodes: Only if join type, hints, and conditions are identical + * - Join nodes: Requires identical join type, hints, and conditions; filter propagation is + * forwarded into the join's children so a filter difference on one child can still be merged * * @param newPlan The plan to merge into the cached plan. * @param cachedPlan The cached plan to merge with. @@ -416,18 +432,37 @@ class PlanMerger( } case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint => - // Filter propagation across joins is not yet supported. 
- tryMergePlans(np.left, cp.left, false).flatMap { - case TryMergeResult(mergedLeft, leftNPMapping, None, None) => - tryMergePlans(np.right, cp.right, false).flatMap { - case TryMergeResult(mergedRight, rightNPMapping, None, None) => + tryMergePlans(np.left, cp.left, filterPropagationSupported).flatMap { + case TryMergeResult(mergedLeft, leftNPMapping, leftNPFilter, leftCPFilter) => + tryMergePlans(np.right, cp.right, filterPropagationSupported).flatMap { + case TryMergeResult(mergedRight, rightNPMapping, rightNPFilter, rightCPFilter) + // If both children independently propagate filter attributes we would need to + // AND them into a new alias above the join, which is not yet supported. + if !(leftNPFilter.isDefined && rightNPFilter.isDefined) && + !(leftCPFilter.isDefined && rightCPFilter.isDefined) && + // Gate join-crossing filter propagation behind its own config flag. + // When no filter attributes are in play the merge is unconditionally safe. + (leftNPFilter.isEmpty && leftCPFilter.isEmpty && + rightNPFilter.isEmpty && rightCPFilter.isEmpty || + filterPropagationThroughJoinEnabled) && + // A filter attribute is only safe to propagate through a join if it comes + // from the "preserved" (non-nullable) side. On the nullable side, unmatched + // rows are NULL-padded so f=NULL, causing FILTER (WHERE f) to incorrectly + // exclude rows that should contribute to the aggregate. Right-side + // attributes are also absent from semi/anti join output. + (leftNPFilter.isEmpty && leftCPFilter.isEmpty || + filterSafeForJoin(fromLeft = true, cp.joinType)) && + (rightNPFilter.isEmpty && rightCPFilter.isEmpty || + filterSafeForJoin(fromLeft = false, cp.joinType)) => val npMapping = leftNPMapping ++ rightNPMapping val mappedNPCondition = np.condition.map(mapAttributes(_, npMapping)) // Comparing the canonicalized form is required to ignore different forms of the // same expression and `AttributeReference.qualifier`s in `cp.condition`. 
if (mappedNPCondition.map(_.canonicalized) == cp.condition.map(_.canonicalized)) { - val mergedPlan = cp.withNewChildren(Seq(mergedLeft, mergedRight)) - Some(TryMergeResult(mergedPlan, npMapping)) + val npFilter = leftNPFilter.orElse(rightNPFilter) + val cpFilter = leftCPFilter.orElse(rightCPFilter) + Some(TryMergeResult(cp.withNewChildren(Seq(mergedLeft, mergedRight)), npMapping, + npFilter, cpFilter)) } else { None } @@ -441,6 +476,35 @@ class PlanMerger( }) } + // Returns true when a filter attribute originating from `fromLeft` child of a join with + // `joinType` can be safely propagated through that join to a parent Aggregate. + // + // Two conditions must both hold: + // 1. The attribute is in the join's output (rules out the right side of LeftSemi/LeftAnti). + // 2. The filter must originate from the non-nullable ("preserved") side of the join. + // When a filter is on the nullable side, the merged base plan no longer applies it to the + // nullable child's scan, so rows that were previously absent from that child reappear as + // matched join rows instead of unmatched NULL-padded rows. This changes aggregate + // expressions that use the NULL-padded column: e.g. for `sum(coalesce(col, default))`, an + // originally unmatched row would have contributed `default` via `coalesce(NULL, default)`, + // but in the merged plan the row is now matched with its real column value, fails the + // filter, and FILTER (WHERE false) discards it -- losing the `default` contribution + // entirely. + private def filterSafeForJoin(fromLeft: Boolean, joinType: JoinType): Boolean = + if (fromLeft) { + // Left side is never NULL-padded in: Inner, LeftOuter, LeftSemi, LeftAnti, Cross. + joinType match { + case Inner | LeftOuter | LeftSemi | LeftAnti | Cross => true + case _ => false // RightOuter and FullOuter can NULL-pad the left side + } + } else { + // Right side is never NULL-padded AND is in the join output in: Inner, RightOuter, Cross. 
+ joinType match { + case Inner | RightOuter | Cross => true + case _ => false // LeftOuter/FullOuter can NULL-pad right; LeftSemi/LeftAnti drop right + } + } + private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = { expr.transform { case a: Attribute => outputMap.getOrElse(a, a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7fbf2797dc93..83f1816c9727 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -6608,6 +6608,20 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED = + buildConf("spark.sql.optimizer.mergeSubplans.filterPropagation.throughJoin.enabled") + .doc("When set to true, filter attributes can propagate through Join nodes during subplan " + + "merging, allowing subplans that differ only in their filter conditions and share a " + + "common join to be merged into a single scan. A filter attribute is only propagated " + + "through a join when it originates from the non-nullable (preserved) side: the left side " + + "of LeftOuter/LeftSemi/LeftAnti, the right side of RightOuter, or either side of " + + "Inner/Cross. FullOuter joins are never eligible. " + + s"Has no effect when ${MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED.key} is false.") + .version("4.2.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val ERROR_MESSAGE_FORMAT = buildConf("spark.sql.error.messageFormat") .doc("When PRETTY, the error message consists of textual representation of error class, " + "message and query context. 
Stack traces are only shown for internal errors " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala index e685c756a4b7..c9e5db581059 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala @@ -37,6 +37,7 @@ class MergeSubplansSuite extends PlanTest { } val testRelation = LocalRelation($"a".int, $"b".int, $"c".string) + val testRelation2 = LocalRelation($"d".int, $"e".int) val testRelationWithNonBinaryCollation = LocalRelation( $"utf8_binary".string("UTF8_BINARY"), $"utf8_lcase".string("UTF8_LCASE")) @@ -1542,4 +1543,230 @@ class MergeSubplansSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) } + + test("SPARK-56677: Merge non-grouping subqueries with filter on left join child") { + // cp (subquery1): Aggregate([], [sum(a)], Join(testRelation, testRelation2, a=d)) + // np (subquery2): Aggregate([], [max(a)], Join(Filter(a>1, testRelation), testRelation2, a=d)) + // The filter on the left join child propagates as a boolean attribute through the Join node + // and is consumed as a FILTER (WHERE ...) clause on the np-side aggregate expression. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()( + sum($"a").as("sum_a"), + max($"a", Some(f0)).as("max_a")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-56677: Merge non-grouping subqueries with filter on right join child") { + // cp (subquery1): Aggregate([], [sum(a)], Join(testRelation, testRelation2, a=d)) + // np (subquery2): Aggregate([], [max(d)], Join(testRelation, Filter(d>1, testRelation2), a=d)) + // The filter on the right join child propagates analogously to the left-child case. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.join(testRelation2.where($"d" > 1), Inner, Some($"a" === $"d")) + .groupBy()(max($"d").as("max_d"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"d" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .join( + testRelation2.select(testRelation2.output ++ Seq(f0Alias): _*), + Inner, Some($"a" === $"d")) + .groupBy()( + sum($"a").as("sum_a"), + max($"d", Some(f0)).as("max_d")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_d"), $"max_d" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-56677: Merge non-grouping subqueries with filter on left child of a Cross join") { + // Cross join never NULL-pads either side, so filter propagation is safe from both sides. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Cross, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, Cross, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .join(testRelation2, Cross, Some($"a" === $"d")) + .groupBy()( + sum($"a").as("sum_a"), + max($"a", Some(f0)).as("max_a")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-56677: Do not merge subqueries when both join children have independent filters") { + // np has filters on BOTH left and right join children simultaneously. The guard in the + // Join case prevents this merge because combining two independent filter attributes would + // require ANDing them into a new alias, which is not yet supported. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2.where($"d" > 1), Inner, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-56677: Merge non-grouping subqueries with filter on left side of LeftSemi join") { + // Left-side filter attributes ARE in the LeftSemi join output, so propagation is safe. + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, LeftSemi, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, LeftSemi, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .join(testRelation2, LeftSemi, Some($"a" === $"d")) + .groupBy()( + sum($"a").as("sum_a"), + max($"a", Some(f0)).as("max_a")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-56677: Do 
not merge subqueries when filter is on the right side of a LeftSemi join") { + // Right-side filter attributes are NOT in the LeftSemi join output (only left-side columns + // are produced). Propagating such a filter would create an unresolvable attribute reference + // in the parent Aggregate's FILTER clause. + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, LeftSemi, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.join(testRelation2.where($"d" > 1), LeftSemi, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-56677: Do not merge subqueries when filter is on the nullable side of an outer " + + "join") { + // For a RightOuter join the left side is nullable: unmatched right rows produce NULL for all + // left-side columns including the filter attribute f, so FILTER (WHERE f) sees NULL and would + // incorrectly exclude those rows from the aggregate even though they appear in the join result. + // The same problem applies to the right side of a LeftOuter join and both sides of FullOuter. 
+ val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, RightOuter, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, RightOuter, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-56677: Do not merge subqueries when filter is on either side of a FullOuter join") { + // For a FullOuter join both sides are nullable: unmatched rows from either side produce NULL + // for the other side's columns. A filter attribute from either side would be NULL for those + // unmatched rows, making propagation unsafe from both sides. + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, FullOuter, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, FullOuter, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-56677: Do not merge subqueries with filter propagation through join when disabled") { + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> "false") { + val subquery1 = ScalarSubquery( + testRelation.join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).join(testRelation2, Inner, Some($"a" === $"d")) + .groupBy()(max($"a").as("max_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + 
comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala index 1e31453b42f2..e1109f20e604 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala @@ -395,4 +395,47 @@ class PlanMergeSuite extends SharedSparkSession } } } + + test("SPARK-56677: Merge scalar subqueries with filter propagation through Join") { + // subquery1 has no filter; subquery2 filters on b > 1 (a column from the right side of the join + // that is not part of the join condition). Predicate pushdown can only push this filter to + // testData2, not to testData, so only the right child differs between the two subqueries. + Seq(false, true).foreach { enableAQE => + Seq(true, false).foreach { filterPropagationThroughJoinEnabled => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key -> + filterPropagationThroughJoinEnabled.toString, + // ObjectSerializerPruning produces different scan shapes depending on whether a Filter is + // present. Disabling the rule makes both scans identical so PlanMerger can merge them. 
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning") { + val df = sql( + """ + |SELECT + | (SELECT sum(key) FROM testData JOIN testData2 ON key = a), + | (SELECT sum(key) FROM testData JOIN testData2 ON key = a WHERE b > 1) + """.stripMargin) + + checkAnswer(df, Row(12, 6) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + if (filterPropagationThroughJoinEnabled) { + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 1, + "Missing or unexpected ReusedSubqueryExec in the plan") + } else { + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 0, + "Missing or unexpected ReusedSubqueryExec in the plan") + } + } + } + } + } }