diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala index a85fad2f758c1..f02ddd1f2d769 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -252,74 +252,12 @@ class PlanMerger( filterPropagationSupported: Boolean): Option[TryMergeResult] = { checkIdenticalPlans(newPlan, cachedPlan).map(TryMergeResult(cachedPlan, _)).orElse( (newPlan, cachedPlan) match { - case (np: Project, cp: Project) => - tryMergePlans(np.child, cp.child, filterPropagationSupported).map { - case TryMergeResult(mergedChild, npMapping, npFilter, cpFilter) => - val (mergedProjectList, newNPMapping) = - mergeNamedExpressions(np.projectList, cp.projectList, npMapping, npFilter, cpFilter) - TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, npFilter, - cpFilter) - } - case (np, cp: Project) => - tryMergePlans(np, cp.child, filterPropagationSupported).map { - case TryMergeResult(mergedChild, npMapping, npFilter, cpFilter) => - val (mergedProjectList, newNPMapping) = - mergeNamedExpressions(np.output, cp.projectList, npMapping, npFilter, cpFilter) - TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, npFilter, - cpFilter) - } - case (np: Project, cp) => - tryMergePlans(np.child, cp, filterPropagationSupported).map { - case TryMergeResult(mergedChild, npMapping, npFilter, cpFilter) => - val (mergedProjectList, newNPMapping) = - mergeNamedExpressions(np.projectList, cp.output, npMapping, npFilter, cpFilter) - TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, npFilter, - cpFilter) - } - - case (np: Aggregate, cp: Aggregate) if supportedAggregateMerge(np, cp) => - // Filter propagation into the aggregate is only safe when there is no grouping. - val childFilterPropagationSupported = filterPropagationEnabled && - np.groupingExpressions.isEmpty && cp.groupingExpressions.isEmpty - tryMergePlans(np.child, cp.child, childFilterPropagationSupported).flatMap { - case TryMergeResult(mergedChild, npMapping, None, None) => - val mappedNPGroupingExpression = - np.groupingExpressions.map(mapAttributes(_, npMapping)) - // Order of grouping expression does matter as merging different grouping orders can - // introduce "extra" shuffles/sorts that might not present in all of the original - // subqueries. - if (mappedNPGroupingExpression.map(_.canonicalized) == - cp.groupingExpressions.map(_.canonicalized)) { - val (mergedAggregateExpressions, newNPMapping) = - mergeNamedExpressions(np.aggregateExpressions, cp.aggregateExpressions, npMapping) - val mergedPlan = - Aggregate(cp.groupingExpressions, mergedAggregateExpressions, mergedChild) - Some(TryMergeResult(mergedPlan, newNPMapping)) - } else { - None - } - case TryMergeResult(mergedChild, npMapping, npFilterOpt, cpFilterOpt) => - // childFilterPropagationSupported guarantees both aggregates have no grouping, so - // the grouping-match check is skipped. - assert(childFilterPropagationSupported) - - // Apply each propagated boolean attribute as a FILTER (WHERE ...) clause on the - // corresponding side's aggregate expressions. - // A None filter means the side's aggregate expressions already carry their individual - // FILTER attributes from a previous merge round and should be left unchanged. - // Filter propagation is consumed here and not passed further up. - val filteredNPAggregateExpressions = npFilterOpt.fold(np.aggregateExpressions) { - case (f, _) => applyFilterToAggregateExpressions(np.aggregateExpressions, f) - } - val filteredCPAggregateExpressions = cpFilterOpt.fold(cp.aggregateExpressions)( - applyFilterToAggregateExpressions(cp.aggregateExpressions, _)) - val (mergedAggregateExpressions, newNPMapping) = - mergeNamedExpressions(filteredNPAggregateExpressions, - filteredCPAggregateExpressions, npMapping) - val mergedPlan = Aggregate(Seq.empty, mergedAggregateExpressions, mergedChild) - Some(TryMergeResult(mergedPlan, newNPMapping)) - } - + // Filter cases must precede the generic Project-peeling cases below. + // When filterPropagationSupported is true, a (Filter, Project) pair must be handled here so + // that the reuse check can find an already-aliased condition in the merged child Project. + // If (np, cp: Project) fired first, it would peel the Project layer and recurse with + // (Filter, ...), where no Project exists yet, causing a redundant alias to be created + // instead of reusing the existing one. case (np: Filter, cp: Filter) => tryMergePlans(np.child, cp.child, filterPropagationSupported).flatMap { case TryMergeResult(mergedChild, npMapping, npFilter, cpFilter) => @@ -406,13 +344,28 @@ class PlanMerger( val newNPCondition = npFilter.fold(mappedNPCondition) { case (f, _) => And(f, mappedNPCondition) } - val newNPFilterAlias = - Alias(newNPCondition, s"propagatedFilter_${PlanMerger.newId}")() - val newNPFilter = newNPFilterAlias.toAttribute - val project = Project( - mergedChild.output.toList ++ Seq(newNPFilterAlias) ++ cpFilter.toSeq, - mergedChild) - TryMergeResult(project, npMapping, Some((newNPFilter, true)), cpFilter) + // If newNPCondition is already aliased in the child Project (e.g. a subsequent + // subplan whose filter matches one already propagated in a previous round), reuse + // the existing attribute instead of creating a redundant alias. + val existingNPFilter = mergedChild match { + case p: Project => p.projectList.collectFirst { + case a: Alias if a.child.canonicalized == newNPCondition.canonicalized => + a.toAttribute + } + case _ => None + } + existingNPFilter match { + case Some(reusedFilter) => + TryMergeResult(mergedChild, npMapping, Some((reusedFilter, false)), cpFilter) + case None => + val newNPFilterAlias = + Alias(newNPCondition, s"propagatedFilter_${PlanMerger.newId}")() + val newNPFilter = newNPFilterAlias.toAttribute + val project = Project( + mergedChild.output.toList ++ Seq(newNPFilterAlias) ++ cpFilter.toSeq, + mergedChild) + TryMergeResult(project, npMapping, Some((newNPFilter, true)), cpFilter) + } } case (np, cp: Filter) if filterPropagationSupported => tryMergePlans(np, cp.child, filterPropagationSupported).collect { @@ -431,6 +384,74 @@ class PlanMerger( TryMergeResult(project, npMapping, npFilter, Some(newCPFilter)) } + case (np: Project, cp: Project) => + tryMergePlans(np.child, cp.child, filterPropagationSupported).map { + case TryMergeResult(mergedChild, npMapping, npFilter, cpFilter) => + val (mergedProjectList, newNPMapping) = + mergeNamedExpressions(np.projectList, cp.projectList, npMapping, npFilter, cpFilter) + TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, npFilter, + cpFilter) + } + case (np, cp: Project) => + tryMergePlans(np, cp.child, filterPropagationSupported).map { + case TryMergeResult(mergedChild, npMapping, npFilter, cpFilter) => + val (mergedProjectList, newNPMapping) = + mergeNamedExpressions(np.output, cp.projectList, npMapping, npFilter, cpFilter) + TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, npFilter, + cpFilter) + } + case (np: Project, cp) => + tryMergePlans(np.child, cp, filterPropagationSupported).map { + case TryMergeResult(mergedChild, npMapping, npFilter, cpFilter) => + val (mergedProjectList, newNPMapping) = + mergeNamedExpressions(np.projectList, cp.output, npMapping, npFilter, cpFilter) + TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, npFilter, + cpFilter) + } + + case (np: Aggregate, cp: Aggregate) if supportedAggregateMerge(np, cp) => + // Filter propagation into the aggregate is only safe when there is no grouping. + val childFilterPropagationSupported = filterPropagationEnabled && + np.groupingExpressions.isEmpty && cp.groupingExpressions.isEmpty + tryMergePlans(np.child, cp.child, childFilterPropagationSupported).flatMap { + case TryMergeResult(mergedChild, npMapping, None, None) => + val mappedNPGroupingExpression = + np.groupingExpressions.map(mapAttributes(_, npMapping)) + // Order of grouping expression does matter as merging different grouping orders can + // introduce "extra" shuffles/sorts that might not present in all of the original + // subqueries. + if (mappedNPGroupingExpression.map(_.canonicalized) == + cp.groupingExpressions.map(_.canonicalized)) { + val (mergedAggregateExpressions, newNPMapping) = + mergeNamedExpressions(np.aggregateExpressions, cp.aggregateExpressions, npMapping) + val mergedPlan = + Aggregate(cp.groupingExpressions, mergedAggregateExpressions, mergedChild) + Some(TryMergeResult(mergedPlan, newNPMapping)) + } else { + None + } + case TryMergeResult(mergedChild, npMapping, npFilterOpt, cpFilterOpt) => + // childFilterPropagationSupported guarantees both aggregates have no grouping, so + // the grouping-match check is skipped. + assert(childFilterPropagationSupported) + + // Apply each propagated boolean attribute as a FILTER (WHERE ...) clause on the + // corresponding side's aggregate expressions. + // A None filter means the side's aggregate expressions already carry their individual + // FILTER attributes from a previous merge round and should be left unchanged. + // Filter propagation is consumed here and not passed further up. + val filteredNPAggregateExpressions = npFilterOpt.fold(np.aggregateExpressions) { + case (f, _) => applyFilterToAggregateExpressions(np.aggregateExpressions, f) + } + val filteredCPAggregateExpressions = cpFilterOpt.fold(cp.aggregateExpressions)( + applyFilterToAggregateExpressions(cp.aggregateExpressions, _)) + val (mergedAggregateExpressions, newNPMapping) = + mergeNamedExpressions(filteredNPAggregateExpressions, + filteredCPAggregateExpressions, npMapping) + val mergedPlan = Aggregate(Seq.empty, mergedAggregateExpressions, mergedChild) + Some(TryMergeResult(mergedPlan, newNPMapping)) + } + case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint => tryMergePlans(np.left, cp.left, filterPropagationSupported).flatMap { case TryMergeResult(mergedLeft, leftNPMapping, leftNPFilter, leftCPFilter) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala index c9e5db5810597..61c603f4842f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala @@ -853,6 +853,42 @@ class MergeSubplansSuite extends PlanTest { } } + test("SPARK-56703: Merge three non-grouping subqueries where the third has the same filter " + + "condition as the second") { + // subquery1 has no filter (cached as cp), subquery2 has filter a > 1 (np, propagates f0 + // via the one-sided (np: Filter, cp) path), subquery3 has the identical filter a > 1. + // After subquery1+2 are merged, f0 = (a > 1) is already aliased in the merged child Project. + // When subquery3 is merged it should reuse f0 rather than creating a redundant f1. + val subquery1 = ScalarSubquery(testRelation.groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(min($"a").as("min_a"))) + val subquery3 = ScalarSubquery( + testRelation.where($"a" > 1).groupBy()(count($"a").as("count_a"))) + val originalQuery = testRelation.select(subquery1, subquery2, subquery3) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .groupBy()( + max($"a").as("max_a"), + min($"a", Some(f0)).as("min_a"), + count($"a", Some(f0)).as("count_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a", + Literal("count_a"), $"count_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1), + extractorExpression(0, analyzedMergedSubquery.output, 2)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + test("SPARK-40193: Do not merge non-grouping subqueries with different filter conditions when " + "disabled") { withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED.key -> "false") {