From db4e04a9b0487fa3f458a9a41c8cef6979dc1a1f Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 2 Nov 2025 12:02:03 +0100 Subject: [PATCH 1/3] [SPARK-54136][SQL] Extract plan merging logic from `MergeScalarSubqueries` to `PlanMerger` --- .../optimizer/MergeScalarSubqueries.scala | 382 +++++------------- .../sql/catalyst/optimizer/PlanMerger.scala | 303 ++++++++++++++ .../MergeScalarSubqueriesSuite.scala | 18 +- 3 files changed, 417 insertions(+), 286 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala index 4f8b2eda92f3..b7db7d73c017 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, Filter, Join, LogicalPlan, Project, Subquery, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LogicalPlan, Project, Subquery, WithCTE} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern} import org.apache.spark.sql.internal.SQLConf @@ -33,23 +31,24 @@ import org.apache.spark.sql.types.DataType * values once. 
* * The process is the following: - * - While traversing through the plan each [[ScalarSubquery]] plan is tried to merge into the cache - * of already seen subquery plans. If merge is possible then cache is updated with the merged - * subquery plan, if not then the new subquery plan is added to the cache. + * - While traversing through the plan each [[ScalarSubquery]] plan is tried to merge into already + * seen subquery plans using `PlanMerger`s. * During this first traversal each [[ScalarSubquery]] expression is replaced to a temporal - * [[ScalarSubqueryReference]] reference pointing to its cached version. - * The cache uses a flag to keep track of if a cache entry is a result of merging 2 or more - * plans, or it is a plan that was seen only once. - * Merged plans in the cache get a "Header", that contains the list of attributes form the scalar - * return value of a merged subquery. - * - A second traversal checks if there are merged subqueries in the cache and builds a `WithCTE` - * node from these queries. The `CTERelationDef` nodes contain the merged subquery in the - * following form: - * `Project(Seq(CreateNamedStruct(name1, attribute1, ...) AS mergedValue), mergedSubqueryPlan)` - * and the definitions are flagged that they host a subquery, that can return maximum one row. - * During the second traversal [[ScalarSubqueryReference]] expressions that pont to a merged - * subquery is either transformed to a `GetStructField(ScalarSubquery(CTERelationRef(...)))` - * expression or restored to the original [[ScalarSubquery]]. + * [[ScalarSubqueryReference]] pointing to its possible merged version stored in `PlanMerger`s. + * `PlanMerger`s keep track of whether a plan is a result of merging 2 or more plans, or is an + * original unmerged plan. [[ScalarSubqueryReference]]s contain all the required information to + * either restore the original [[ScalarSubquery]] or create a reference to a merged CTE. 
+ * - Once the first traversal is complete and all possible merges have been done, a second traversal + * removes the [[ScalarSubqueryReference]]s to either restore the original [[ScalarSubquery]] or + * to replace the original with a modified one that references a CTE with a merged plan. + * A modified [[ScalarSubquery]] is constructed like: + * `GetStructField(ScalarSubquery(CTERelationRef(...)), outputIndex)` where `outputIndex` is the + * index of the output attribute (of the CTE) that corresponds to the output of the original + * subquery. + * - If there are merged subqueries in `PlanMerger`s then a `WithCTE` node is built from these + * queries. The `CTERelationDef` nodes contain the merged subquery in the following form: + * `Project(Seq(CreateNamedStruct(name1, attribute1, ...) AS mergedValue), mergedSubqueryPlan)`. + * The definitions are flagged as hosting a subquery that can return at most one row. * * Eg. the following query: * @@ -117,38 +116,47 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { } } - /** - * An item in the cache of merged scalar subqueries. - * - * @param plan The plan of a merged scalar subquery. - * @param merged A flag to identify if this item is the result of merging subqueries. - * Please note that `attributes.size == 1` doesn't always mean that the plan is not - * merged as there can be subqueries that are different ([[checkIdenticalPlans]] is - * false) due to an extra [[Project]] node in one of them. In that case - * `attributes.size` remains 1 after merging, but the merged flag becomes true. - * @param references A set of subquery indexes in the cache to track all (including transitive) - * nested subqueries. 
- */ - case class Header( - plan: LogicalPlan, - merged: Boolean, - references: Set[Int]) - private def extractCommonScalarSubqueries(plan: LogicalPlan) = { - val cache = ArrayBuffer.empty[Header] - val planWithReferences = insertReferences(plan, cache) - cache.zipWithIndex.foreach { case (header, i) => - cache(i) = cache(i).copy(plan = - if (header.merged) { + // Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert references in place of + // `ScalarSubquery`s. + val planMergers = ArrayBuffer.empty[PlanMerger] + val planWithReferences = insertReferences(plan, planMergers)._1 + + // Traverse level by level and convert merged plans to `CTERelationDef`s and keep non-merged + // ones. While traversing replace references in plans back to `CTERelationRef`s or to original + // `ScalarSubquery`s. This is safe as a subquery plan at a level can reference only lower level + // other subqueries. + val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]] + planMergers.foreach { planMerger => + val mergedPlans = planMerger.mergedPlans() + subqueryPlansByLevel += mergedPlans.map { mergedPlan => + val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) { + // Level 0 plans can't contain references + mergedPlan.plan + } else { + removeReferences(mergedPlan.plan, subqueryPlansByLevel) + } + if (mergedPlan.merged) { CTERelationDef( - createProject(header.plan.output, removeReferences(header.plan, cache)), + Project( + Seq(Alias( + CreateNamedStruct( + planWithoutReferences.output.flatMap(a => Seq(Literal(a.name), a))), + "mergedValue")()), + planWithoutReferences), underSubquery = true) } else { - removeReferences(header.plan, cache) - }) + planWithoutReferences + } + } } - val newPlan = removeReferences(planWithReferences, cache) - val subqueryCTEs = cache.filter(_.merged).map(_.plan.asInstanceOf[CTERelationDef]) + + // Replace references back to `CTERelationRef`s or to original `ScalarSubquery`s in the main + // plan. 
+ val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel) + + // Add `CTERelationDef`s to the plan. + val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte }) if (subqueryCTEs.nonEmpty) { WithCTE(newPlan, subqueryCTEs.toSeq) } else { @@ -156,214 +164,38 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { } } - // First traversal builds up the cache and inserts `ScalarSubqueryReference`s to the plan. - private def insertReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]): LogicalPlan = { - plan.transformUpWithSubqueries { - case n => n.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) { - // The subquery could contain a hint that is not propagated once we cache it, but as a - // non-correlated scalar subquery won't be turned into a Join the loss of hints is fine. + // First traversal inserts `ScalarSubqueryReference`s to the plan and tries to merge subquery + // plans by each level. + private def insertReferences( + plan: LogicalPlan, + planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = { + // The level of a subquery plan is maximum level of its inner subqueries + 1 or 0 if it has no + // inner subqueries. + var maxLevel = 0 + val planWithReferences = + plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) { case s: ScalarSubquery if !s.isCorrelated && s.deterministic => - val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache) - ScalarSubqueryReference(subqueryIndex, headerIndex, s.dataType, s.exprId) - } - } - } - - // Caching returns the index of the subquery in the cache and the index of scalar member in the - // "Header". 
- private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): (Int, Int) = { - val output = plan.output.head - val references = mutable.HashSet.empty[Int] - plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { - case ssr: ScalarSubqueryReference => - references += ssr.subqueryIndex - references ++= cache(ssr.subqueryIndex).references - ssr - } - - cache.zipWithIndex.collectFirst(Function.unlift { - case (header, subqueryIndex) if !references.contains(subqueryIndex) => - checkIdenticalPlans(plan, header.plan).map { outputMap => - val mappedOutput = mapAttributes(output, outputMap) - val headerIndex = header.plan.output.indexWhere(_.exprId == mappedOutput.exprId) - subqueryIndex -> headerIndex - }.orElse{ - tryMergePlans(plan, header.plan).map { - case (mergedPlan, outputMap) => - val mappedOutput = mapAttributes(output, outputMap) - val headerIndex = mergedPlan.output.indexWhere(_.exprId == mappedOutput.exprId) - cache(subqueryIndex) = Header(mergedPlan, true, header.references ++ references) - subqueryIndex -> headerIndex - } - } - case _ => None - }).getOrElse { - cache += Header(plan, false, references.toSet) - cache.length - 1 -> 0 - } - } - - // If 2 plans are identical return the attribute mapping from the new to the cached version. - private def checkIdenticalPlans( - newPlan: LogicalPlan, - cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = { - if (newPlan.canonicalized == cachedPlan.canonicalized) { - Some(AttributeMap(newPlan.output.zip(cachedPlan.output))) - } else { - None - } - } - - // Recursively traverse down and try merging 2 plans. If merge is possible then return the merged - // plan with the attribute mapping from the new to the merged version. - // Please note that merging arbitrary plans can be complicated, the current version supports only - // some of the most important nodes. 
- private def tryMergePlans( - newPlan: LogicalPlan, - cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = { - checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse( - (newPlan, cachedPlan) match { - case (np: Project, cp: Project) => - tryMergePlans(np.child, cp.child).map { case (mergedChild, outputMap) => - val (mergedProjectList, newOutputMap) = - mergeNamedExpressions(np.projectList, outputMap, cp.projectList) - val mergedPlan = Project(mergedProjectList, mergedChild) - mergedPlan -> newOutputMap - } - case (np, cp: Project) => - tryMergePlans(np, cp.child).map { case (mergedChild, outputMap) => - val (mergedProjectList, newOutputMap) = - mergeNamedExpressions(np.output, outputMap, cp.projectList) - val mergedPlan = Project(mergedProjectList, mergedChild) - mergedPlan -> newOutputMap - } - case (np: Project, cp) => - tryMergePlans(np.child, cp).map { case (mergedChild, outputMap) => - val (mergedProjectList, newOutputMap) = - mergeNamedExpressions(np.projectList, outputMap, cp.output) - val mergedPlan = Project(mergedProjectList, mergedChild) - mergedPlan -> newOutputMap - } - case (np: Aggregate, cp: Aggregate) if supportedAggregateMerge(np, cp) => - tryMergePlans(np.child, cp.child).flatMap { case (mergedChild, outputMap) => - val mappedNewGroupingExpression = - np.groupingExpressions.map(mapAttributes(_, outputMap)) - // Order of grouping expression does matter as merging different grouping orders can - // introduce "extra" shuffles/sorts that might not present in all of the original - // subqueries. 
- if (mappedNewGroupingExpression.map(_.canonicalized) == - cp.groupingExpressions.map(_.canonicalized)) { - val (mergedAggregateExpressions, newOutputMap) = - mergeNamedExpressions(np.aggregateExpressions, outputMap, cp.aggregateExpressions) - val mergedPlan = - Aggregate(cp.groupingExpressions, mergedAggregateExpressions, mergedChild) - Some(mergedPlan -> newOutputMap) - } else { - None - } - } - - case (np: Filter, cp: Filter) => - tryMergePlans(np.child, cp.child).flatMap { case (mergedChild, outputMap) => - val mappedNewCondition = mapAttributes(np.condition, outputMap) - // Comparing the canonicalized form is required to ignore different forms of the same - // expression. - if (mappedNewCondition.canonicalized == cp.condition.canonicalized) { - val mergedPlan = cp.withNewChildren(Seq(mergedChild)) - Some(mergedPlan -> outputMap) - } else { - None - } - } - - case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint => - tryMergePlans(np.left, cp.left).flatMap { case (mergedLeft, leftOutputMap) => - tryMergePlans(np.right, cp.right).flatMap { case (mergedRight, rightOutputMap) => - val outputMap = leftOutputMap ++ rightOutputMap - val mappedNewCondition = np.condition.map(mapAttributes(_, outputMap)) - // Comparing the canonicalized form is required to ignore different forms of the same - // expression and `AttributeReference.quailifier`s in `cp.condition`. - if (mappedNewCondition.map(_.canonicalized) == cp.condition.map(_.canonicalized)) { - val mergedPlan = cp.withNewChildren(Seq(mergedLeft, mergedRight)) - Some(mergedPlan -> outputMap) - } else { - None - } - } - } - - // Otherwise merging is not possible. 
- case _ => None - }) - } - - private def createProject(attributes: Seq[Attribute], plan: LogicalPlan): Project = { - Project( - Seq(Alias( - CreateNamedStruct(attributes.flatMap(a => Seq(Literal(a.name), a))), - "mergedValue")()), - plan) - } - - private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = { - expr.transform { - case a: Attribute => outputMap.getOrElse(a, a) - }.asInstanceOf[T] - } - - // Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into - // `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to - // the merged version that can be propagated up during merging nodes. - private def mergeNamedExpressions( - newExpressions: Seq[NamedExpression], - outputMap: AttributeMap[Attribute], - cachedExpressions: Seq[NamedExpression]) = { - val mergedExpressions = ArrayBuffer[NamedExpression](cachedExpressions: _*) - val newOutputMap = AttributeMap(newExpressions.map { ne => - val mapped = mapAttributes(ne, outputMap) - val withoutAlias = mapped match { - case Alias(child, _) => child - case e => e - } - ne.toAttribute -> mergedExpressions.find { - case Alias(child, _) => child semanticEquals withoutAlias - case e => e semanticEquals withoutAlias - }.getOrElse { - mergedExpressions += mapped - mapped - }.toAttribute - }) - (mergedExpressions.toSeq, newOutputMap) - } - - // Only allow aggregates of the same implementation because merging different implementations - // could cause performance regression. 
- private def supportedAggregateMerge(newPlan: Aggregate, cachedPlan: Aggregate) = { - val aggregateExpressionsSeq = Seq(newPlan, cachedPlan).map { plan => - plan.aggregateExpressions.flatMap(_.collect { - case a: AggregateExpression => a - }) - } - val groupByExpressionSeq = Seq(newPlan, cachedPlan).map(_.groupingExpressions) - - val Seq(newPlanSupportsHashAggregate, cachedPlanSupportsHashAggregate) = - aggregateExpressionsSeq.zip(groupByExpressionSeq).map { - case (aggregateExpressions, groupByExpressions) => - Aggregate.supportsHashAggregate( - aggregateExpressions.flatMap( - _.aggregateFunction.aggBufferAttributes), groupByExpressions) - } - - newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate || - newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && { - val Seq(newPlanSupportsObjectHashAggregate, cachedPlanSupportsObjectHashAggregate) = - aggregateExpressionsSeq.zip(groupByExpressionSeq).map { - case (aggregateExpressions, groupByExpressions) => - Aggregate.supportsObjectHashAggregate(aggregateExpressions, groupByExpressions) - } - newPlanSupportsObjectHashAggregate && cachedPlanSupportsObjectHashAggregate || - newPlanSupportsObjectHashAggregate == cachedPlanSupportsObjectHashAggregate + val (planWithReferences, level) = insertReferences(s.plan, planMergers) + + while (level >= planMergers.size) planMergers += PlanMerger() + // The subquery could contain a hint that is not propagated once we merge it, but as a + // non-correlated scalar subquery won't be turned into a Join the loss of hints is fine. 
+ val planMergeResult = planMergers(level).merge(planWithReferences) + + maxLevel = maxLevel.max(level + 1) + + val mergedOutput = planMergeResult.outputMap(planWithReferences.output.head) + val headerIndex = + planMergeResult.mergedPlan.output.indexWhere(_.exprId == mergedOutput.exprId) + ScalarSubqueryReference( + level, + planMergeResult.mergedPlanIndex, + headerIndex, + s.dataType, + s.exprId) + case o => o } + (planWithReferences, maxLevel) } // Second traversal replaces `ScalarSubqueryReference`s to either @@ -371,43 +203,39 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { // multiple subqueries or `ScalarSubquery(original plan)` if it isn't. private def removeReferences( plan: LogicalPlan, - cache: ArrayBuffer[Header]) = { - plan.transformUpWithSubqueries { - case n => - n.transformExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { - case ssr: ScalarSubqueryReference => - val header = cache(ssr.subqueryIndex) - if (header.merged) { - val subqueryCTE = header.plan.asInstanceOf[CTERelationDef] - GetStructField( - ScalarSubquery( - CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, - subqueryCTE.isStreaming), - exprId = ssr.exprId), - ssr.headerIndex) - } else { - ScalarSubquery(header.plan, exprId = ssr.exprId) - } + subqueryPlansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = { + plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { + case ssr: ScalarSubqueryReference => + subqueryPlansByLevel(ssr.level)(ssr.mergedPlanIndex) match { + case cte: CTERelationDef => + GetStructField( + ScalarSubquery( + CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming), + exprId = ssr.exprId), + ssr.outputIndex) + case o => ScalarSubquery(o, exprId = ssr.exprId) } } } } /** - * Temporal reference to a cached subquery. + * Temporary reference to a subquery which is added to a `PlanMerger`. * - * @param subqueryIndex A subquery index in the cache. 
- * @param headerIndex An index in the output of merged subquery. - * @param dataType The dataType of origin scalar subquery. + * @param level The level of the replaced subquery. It defines the `PlanMerger` instance into which + * the subquery is merged. + * @param mergedPlanIndex The index of the merged plan in the `PlanMerger`. + * @param outputIndex The index of the output attribute of the merged plan. + * @param dataType The dataType of original scalar subquery. + * @param exprId The expression id of the original scalar subquery. */ case class ScalarSubqueryReference( - subqueryIndex: Int, - headerIndex: Int, + level: Int, + mergedPlanIndex: Int, + outputIndex: Int, dataType: DataType, exprId: ExprId) extends LeafExpression with Unevaluable { override def nullable: Boolean = true final override val nodePatterns: Seq[TreePattern] = Seq(SCALAR_SUBQUERY_REFERENCE) - - override def stringArgs: Iterator[Any] = Iterator(subqueryIndex, headerIndex, dataType, exprId.id) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala new file mode 100644 index 000000000000..638aa6627e99 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project} + +/** + * Result of attempting to merge a plan via [[PlanMerger.merge]]. + * + * @param mergedPlan The resulting plan, either: + * - An existing cached plan (if identical match found) + * - A newly merged plan combining the input with a cached plan + * - The original input plan (if no merge was possible) + * @param mergedPlanIndex The index of this plan in the PlanMerger's cache. + * @param merged Whether the plan was merged with an existing cached plan (true) or + * is a new entry (false). + * @param outputMap Maps attributes from the input plan to corresponding attributes in + * `mergedPlan`. Used to rewrite expressions referencing the original plan + * to reference the merged plan instead. + */ +case class MergeResult( + mergedPlan: LogicalPlan, + mergedPlanIndex: Int, + merged: Boolean, + outputMap: AttributeMap[Attribute]) + +/** + * Represents a plan in the PlanMerger's cache. + * + * @param plan The logical plan, which may have been merged from multiple original plans. + * @param merged Whether this plan is the result of merging two or more plans (true), or + * is an original unmerged plan (false). 
Merged plans typically require special + * handling such as wrapping in CTEs. + */ +case class MergedPlan(plan: LogicalPlan, merged: Boolean) + +/** + * A stateful utility for merging identical or similar logical plans to enable query plan reuse. + * + * `PlanMerger` maintains a cache of previously seen plans and attempts to either: + * 1. Reuse an identical plan already in the cache + * 2. Merge a new plan with a cached plan by combining their outputs + * + * The merging process preserves semantic equivalence while combining outputs from multiple + * plans into a single plan. This is primarily used by [[MergeScalarSubqueries]] to deduplicate + * scalar subquery execution. + * + * Supported plan types for merging: + * - [[Project]]: Merges project lists + * - [[Aggregate]]: Merges aggregate expressions with identical grouping + * - [[Filter]]: Requires identical filter conditions + * - [[Join]]: Requires identical join type, hints, and conditions + * + * @example + * {{{ + * val merger = PlanMerger() + * val result1 = merger.merge(plan1) // Adds plan1 to cache + * val result2 = merger.merge(plan2) // Merges with plan1 if compatible + * // result2.merged == true if plans were merged + * // result2.outputMap maps plan2's attributes to the merged plan's attributes + * }}} + */ +class PlanMerger { + val cache = ArrayBuffer.empty[MergedPlan] + + /** + * Attempts to merge the given plan with cached plans, or adds it to the cache. + * + * The method tries the following in order: + * 1. Check if an identical plan exists in cache (using canonicalized comparison) + * 2. Try to merge with each cached plan using [[tryMergePlans]] + * 3. If no merge is possible, add as a new cache entry + * + * @param plan The logical plan to merge or cache. 
+ * @return A [[MergeResult]] containing: + * - The merged/cached plan to use + * - Its index in the cache + * - Whether it was merged with an existing plan + * - An attribute mapping for rewriting expressions + */ + def merge(plan: LogicalPlan): MergeResult = { + cache.zipWithIndex.collectFirst(Function.unlift { + case (mp, i) => + checkIdenticalPlans(plan, mp.plan).map { outputMap => + MergeResult(mp.plan, i, true, outputMap) + }.orElse { + tryMergePlans(plan, mp.plan).map { + case (mergedPlan, outputMap) => + cache(i) = MergedPlan(mergedPlan, true) + MergeResult(mergedPlan, i, true, outputMap) + } + } + case _ => None + }).getOrElse { + cache += MergedPlan(plan, false) + val outputMap = AttributeMap(plan.output.map(a => a -> a)) + MergeResult(plan, cache.length - 1, false, outputMap) + } + } + + /** + * Returns all plans currently in the cache as an immutable indexed sequence. + * + * @return An indexed sequence of [[MergedPlan]]s in cache order. The index of each plan + * corresponds to the `mergedPlanIndex` returned by [[merge]]. + */ + def mergedPlans(): IndexedSeq[MergedPlan] = cache.toIndexedSeq + + // If 2 plans are identical return the attribute mapping from the new to the cached version. + private def checkIdenticalPlans( + newPlan: LogicalPlan, + cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = { + if (newPlan.canonicalized == cachedPlan.canonicalized) { + Some(AttributeMap(newPlan.output.zip(cachedPlan.output))) + } else { + None + } + } + + /** + * Recursively attempts to merge two plans by traversing their tree structures. 
+ * + * Two plans can be merged if: + * - They are identical (canonicalized forms match), OR + * - They have compatible root nodes with mergeable children + * + * Supported merge patterns: + * - Project nodes: Combines project lists from both plans + * - Aggregate nodes: Combines aggregate expressions if grouping is identical and both + * support the same aggregate implementation (hash/object-hash/sort-based) + * - Filter nodes: Only if filter conditions are identical + * - Join nodes: Only if join type, hints, and conditions are identical + * + * @param newPlan The plan to merge into the cached plan. + * @param cachedPlan The cached plan to merge with. + * @return Some((mergedPlan, outputMap)) if merge succeeds, where: + * - mergedPlan is the combined plan + * - outputMap maps newPlan's attributes to mergedPlan's attributes + * Returns None if plans cannot be merged. + */ + private def tryMergePlans( + newPlan: LogicalPlan, + cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = { + checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse( + (newPlan, cachedPlan) match { + case (np: Project, cp: Project) => + tryMergePlans(np.child, cp.child).map { case (mergedChild, outputMap) => + val (mergedProjectList, newOutputMap) = + mergeNamedExpressions(np.projectList, outputMap, cp.projectList) + val mergedPlan = Project(mergedProjectList, mergedChild) + mergedPlan -> newOutputMap + } + case (np, cp: Project) => + tryMergePlans(np, cp.child).map { case (mergedChild, outputMap) => + val (mergedProjectList, newOutputMap) = + mergeNamedExpressions(np.output, outputMap, cp.projectList) + val mergedPlan = Project(mergedProjectList, mergedChild) + mergedPlan -> newOutputMap + } + case (np: Project, cp) => + tryMergePlans(np.child, cp).map { case (mergedChild, outputMap) => + val (mergedProjectList, newOutputMap) = + mergeNamedExpressions(np.projectList, outputMap, cp.output) + val mergedPlan = Project(mergedProjectList, mergedChild) + 
mergedPlan -> newOutputMap + } + case (np: Aggregate, cp: Aggregate) if supportedAggregateMerge(np, cp) => + tryMergePlans(np.child, cp.child).flatMap { case (mergedChild, outputMap) => + val mappedNewGroupingExpression = + np.groupingExpressions.map(mapAttributes(_, outputMap)) + // Order of grouping expression does matter as merging different grouping orders can + // introduce "extra" shuffles/sorts that might not present in all of the original + // subqueries. + if (mappedNewGroupingExpression.map(_.canonicalized) == + cp.groupingExpressions.map(_.canonicalized)) { + val (mergedAggregateExpressions, newOutputMap) = + mergeNamedExpressions(np.aggregateExpressions, outputMap, cp.aggregateExpressions) + val mergedPlan = + Aggregate(cp.groupingExpressions, mergedAggregateExpressions, mergedChild) + Some(mergedPlan -> newOutputMap) + } else { + None + } + } + + case (np: Filter, cp: Filter) => + tryMergePlans(np.child, cp.child).flatMap { case (mergedChild, outputMap) => + val mappedNewCondition = mapAttributes(np.condition, outputMap) + // Comparing the canonicalized form is required to ignore different forms of the same + // expression. + if (mappedNewCondition.canonicalized == cp.condition.canonicalized) { + val mergedPlan = cp.withNewChildren(Seq(mergedChild)) + Some(mergedPlan -> outputMap) + } else { + None + } + } + + case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint => + tryMergePlans(np.left, cp.left).flatMap { case (mergedLeft, leftOutputMap) => + tryMergePlans(np.right, cp.right).flatMap { case (mergedRight, rightOutputMap) => + val outputMap = leftOutputMap ++ rightOutputMap + val mappedNewCondition = np.condition.map(mapAttributes(_, outputMap)) + // Comparing the canonicalized form is required to ignore different forms of the same + // expression and `AttributeReference.qualifier`s in `cp.condition`. 
+ if (mappedNewCondition.map(_.canonicalized) == cp.condition.map(_.canonicalized)) { + val mergedPlan = cp.withNewChildren(Seq(mergedLeft, mergedRight)) + Some(mergedPlan -> outputMap) + } else { + None + } + } + } + + // Otherwise merging is not possible. + case _ => None + }) + } + + private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = { + expr.transform { + case a: Attribute => outputMap.getOrElse(a, a) + }.asInstanceOf[T] + } + + // Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into + // `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to + // the merged version that can be propagated up during merging nodes. + private def mergeNamedExpressions( + newExpressions: Seq[NamedExpression], + outputMap: AttributeMap[Attribute], + cachedExpressions: Seq[NamedExpression]) = { + val mergedExpressions = ArrayBuffer[NamedExpression](cachedExpressions: _*) + val newOutputMap = AttributeMap(newExpressions.map { ne => + val mapped = mapAttributes(ne, outputMap) + val withoutAlias = mapped match { + case Alias(child, _) => child + case e => e + } + ne.toAttribute -> mergedExpressions.find { + case Alias(child, _) => child semanticEquals withoutAlias + case e => e semanticEquals withoutAlias + }.getOrElse { + mergedExpressions += mapped + mapped + }.toAttribute + }) + (mergedExpressions.toSeq, newOutputMap) + } + + // Only allow aggregates of the same implementation because merging different implementations + // could cause performance regression. 
+ private def supportedAggregateMerge(newPlan: Aggregate, cachedPlan: Aggregate) = { + val aggregateExpressionsSeq = Seq(newPlan, cachedPlan).map { plan => + plan.aggregateExpressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + } + val groupByExpressionSeq = Seq(newPlan, cachedPlan).map(_.groupingExpressions) + + val Seq(newPlanSupportsHashAggregate, cachedPlanSupportsHashAggregate) = + aggregateExpressionsSeq.zip(groupByExpressionSeq).map { + case (aggregateExpressions, groupByExpressions) => + Aggregate.supportsHashAggregate( + aggregateExpressions.flatMap( + _.aggregateFunction.aggBufferAttributes), groupByExpressions) + } + + newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate || + newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && { + val Seq(newPlanSupportsObjectHashAggregate, cachedPlanSupportsObjectHashAggregate) = + aggregateExpressionsSeq.zip(groupByExpressionSeq).map { + case (aggregateExpressions, groupByExpressions) => + Aggregate.supportsObjectHashAggregate(aggregateExpressions, groupByExpressions) + } + newPlanSupportsObjectHashAggregate && cachedPlanSupportsObjectHashAggregate || + newPlanSupportsObjectHashAggregate == cachedPlanSupportsObjectHashAggregate + } + } +} + +object PlanMerger { + def apply(): PlanMerger = new PlanMerger +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala index 636a280bce00..c167b80c1827 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala @@ -568,14 +568,14 @@ class MergeScalarSubqueriesSuite extends PlanTest { val subquery5 = ScalarSubquery(testRelation.select((Symbol("a") + 2).as("a_plus2_2"))) val subquery6 = 
ScalarSubquery(testRelation.select(Symbol("b").as("b_2"))) val originalQuery = testRelation - .select( - subquery1, - subquery2, - subquery3) .where( subquery4 + subquery5 + subquery6 === 0) + .select( + subquery1, + subquery2, + subquery3) val mergedSubquery = testRelation .select( @@ -591,14 +591,14 @@ class MergeScalarSubqueriesSuite extends PlanTest { val analyzedMergedSubquery = mergedSubquery.analyze val correctAnswer = WithCTE( testRelation - .select( - extractorExpression(0, analyzedMergedSubquery.output, 0), - extractorExpression(0, analyzedMergedSubquery.output, 1), - extractorExpression(0, analyzedMergedSubquery.output, 2)) .where( extractorExpression(0, analyzedMergedSubquery.output, 0) + extractorExpression(0, analyzedMergedSubquery.output, 1) + - extractorExpression(0, analyzedMergedSubquery.output, 2) === 0), + extractorExpression(0, analyzedMergedSubquery.output, 2) === 0) + .select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1), + extractorExpression(0, analyzedMergedSubquery.output, 2)), Seq(definitionNode(analyzedMergedSubquery, 0))) comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) From f6899a0f28e893dab4c2f99b9c82f2c229fb69d9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 7 Nov 2025 12:55:13 +0100 Subject: [PATCH 2/3] fix review findings --- .../optimizer/MergeScalarSubqueries.scala | 12 ++++----- .../sql/catalyst/optimizer/PlanMerger.scala | 26 ++++++++----------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala index b7db7d73c017..0ef98fc19c8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala @@ -136,7 +136,7 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { } else { removeReferences(mergedPlan.plan, subqueryPlansByLevel) } - if (mergedPlan.merged) { + if (mergedPlan.merged && mergedPlan.plan.output.size > 1) { CTERelationDef( Project( Seq(Alias( @@ -177,19 +177,19 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { case s: ScalarSubquery if !s.isCorrelated && s.deterministic => val (planWithReferences, level) = insertReferences(s.plan, planMergers) - while (level >= planMergers.size) planMergers += PlanMerger() + while (level >= planMergers.size) planMergers += new PlanMerger() // The subquery could contain a hint that is not propagated once we merge it, but as a // non-correlated scalar subquery won't be turned into a Join the loss of hints is fine. - val planMergeResult = planMergers(level).merge(planWithReferences) + val mergeResult = planMergers(level).merge(planWithReferences) maxLevel = maxLevel.max(level + 1) - val mergedOutput = planMergeResult.outputMap(planWithReferences.output.head) + val mergedOutput = mergeResult.outputMap(planWithReferences.output.head) val headerIndex = - planMergeResult.mergedPlan.output.indexWhere(_.exprId == mergedOutput.exprId) + mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == mergedOutput.exprId) ScalarSubqueryReference( level, - planMergeResult.mergedPlanIndex, + mergeResult.mergedPlanIndex, headerIndex, s.dataType, s.exprId) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala index 638aa6627e99..37982d163927 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -31,16 +31,13 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, Log * - A newly merged plan combining the input with a cached plan * - The original input plan (if no merge was possible) * @param mergedPlanIndex The index of this plan in the PlanMerger's cache. - * @param merged Whether the plan was merged with an existing cached plan (true) or - * is a new entry (false). * @param outputMap Maps attributes from the input plan to corresponding attributes in * `mergedPlan`. Used to rewrite expressions referencing the original plan * to reference the merged plan instead. */ case class MergeResult( - mergedPlan: LogicalPlan, + mergedPlan: MergedPlan, mergedPlanIndex: Int, - merged: Boolean, outputMap: AttributeMap[Attribute]) /** @@ -75,7 +72,7 @@ case class MergedPlan(plan: LogicalPlan, merged: Boolean) - * val merger = PlanMerger() + * val merger = new PlanMerger() * val result1 = merger.merge(plan1) // Adds plan1 to cache * val result2 = merger.merge(plan2) // Merges with plan1 if compatible - * // result2.merged == true if plans were merged + * // result2.mergedPlan.merged == true if plans were merged * // result2.outputMap maps plan2's attributes to the merged plan's attributes * }}} */ @@ -94,26 +91,29 @@ class PlanMerger { * @return A [[MergeResult]] containing: * - The merged/cached plan to use * - Its index in the cache - * - Whether it was merged with an existing plan * - An attribute mapping for rewriting expressions */ def merge(plan: LogicalPlan): MergeResult = { cache.zipWithIndex.collectFirst(Function.unlift { case (mp, i) => checkIdenticalPlans(plan, mp.plan).map { outputMap => - MergeResult(mp.plan, i, true, outputMap) + val newMergePlan = MergedPlan(mp.plan, true) + cache(i) = newMergePlan + MergeResult(newMergePlan, i, outputMap) }.orElse { tryMergePlans(plan, mp.plan).map { case (mergedPlan, outputMap) => - cache(i) = MergedPlan(mergedPlan, true) - MergeResult(mergedPlan, i, true, outputMap) + val newMergePlan = MergedPlan(mergedPlan, true) + cache(i) = newMergePlan +
MergeResult(newMergePlan, i, outputMap) } } case _ => None }).getOrElse { - cache += MergedPlan(plan, false) + val newMergePlan = MergedPlan(plan, false) + cache += newMergePlan val outputMap = AttributeMap(plan.output.map(a => a -> a)) - MergeResult(plan, cache.length - 1, false, outputMap) + MergeResult(newMergePlan, cache.length - 1, outputMap) } } @@ -297,7 +297,3 @@ class PlanMerger { } } } - -object PlanMerger { - def apply(): PlanMerger = new PlanMerger -} From 228c95f9678709769db2ad78012354c4e99af01b Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 7 Nov 2025 14:27:01 +0100 Subject: [PATCH 3/3] minor rename --- .../spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala index 0ef98fc19c8b..45b8437bad05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala @@ -185,12 +185,12 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { maxLevel = maxLevel.max(level + 1) val mergedOutput = mergeResult.outputMap(planWithReferences.output.head) - val headerIndex = + val outputIndex = mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == mergedOutput.exprId) ScalarSubqueryReference( level, mergeResult.mergedPlanIndex, - headerIndex, + outputIndex, s.dataType, s.exprId) case o => o