From e7f777f29295115c61ccc8488af2cbe983c2344b Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 11 Nov 2025 20:43:35 +0100 Subject: [PATCH 1/3] [SPARK-44571][SQL] Merge subplans with one row result --- .../expressions/BloomFilterMightContain.scala | 2 + .../optimizer/MergeScalarSubqueries.scala | 241 ----------- .../catalyst/optimizer/MergeSubplans.scala | 373 ++++++++++++++++++ .../sql/catalyst/optimizer/PlanMerger.scala | 13 +- .../sql/catalyst/trees/TreePatterns.scala | 1 + ...esSuite.scala => MergeSubplansSuite.scala} | 135 ++++++- .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../spark/sql/InjectRuntimeFilterSuite.scala | 4 +- .../org/apache/spark/sql/SubquerySuite.scala | 98 +++++ 9 files changed, 620 insertions(+), 251 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{MergeScalarSubqueriesSuite.scala => MergeSubplansSuite.scala} (82%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala index 784bea899c4c..e3ff7c5f05f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper +import org.apache.spark.sql.catalyst.optimizer.ScalarSubqueryReference import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.BloomFilter @@ -58,6 +59,7 @@ case class BloomFilterMightContain( case GetStructField(subquery: PlanExpression[_], _, _) if !subquery.containsPattern(OUTER_REFERENCE) => TypeCheckResult.TypeCheckSuccess + case _: ScalarSubqueryReference => TypeCheckResult.TypeCheckSuccess case _ => DataTypeMismatch( errorSubClass = "BLOOM_FILTER_BINARY_OP_WRONG_TYPE", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala deleted file mode 100644 index 45b8437bad05..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ /dev/null @@ -1,241 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LogicalPlan, Project, Subquery, WithCTE} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType - -/** - * This rule tries to merge multiple non-correlated [[ScalarSubquery]]s to compute multiple scalar - * values once. - * - * The process is the following: - * - While traversing through the plan each [[ScalarSubquery]] plan is tried to merge into already - * seen subquery plans using `PlanMerger`s. - * During this first traversal each [[ScalarSubquery]] expression is replaced to a temporal - * [[ScalarSubqueryReference]] pointing to its possible merged version stored in `PlanMerger`s. - * `PlanMerger`s keep track of whether a plan is a result of merging 2 or more plans, or is an - * original unmerged plan. [[ScalarSubqueryReference]]s contain all the required information to - * either restore the original [[ScalarSubquery]] or create a reference to a merged CTE. - * - Once the first traversal is complete and all possible merging have been done a second traversal - * removes the [[ScalarSubqueryReference]]s to either restore the original [[ScalarSubquery]] or - * to replace the original to a modified one that references a CTE with a merged plan. - * A modified [[ScalarSubquery]] is constructed like: - * `GetStructField(ScalarSubquery(CTERelationRef(...)), outputIndex)` where `outputIndex` is the - * index of the output attribute (of the CTE) that corresponds to the output of the original - * subquery. - * - If there are merged subqueries in `PlanMerger`s then a `WithCTE` node is built from these - * queries. The `CTERelationDef` nodes contain the merged subquery in the following form: - * `Project(Seq(CreateNamedStruct(name1, attribute1, ...) AS mergedValue), mergedSubqueryPlan)`. - * The definitions are flagged that they host a subquery, that can return maximum one row. - * - * Eg. the following query: - * - * SELECT - * (SELECT avg(a) FROM t), - * (SELECT sum(b) FROM t) - * - * is optimized from: - * - * == Optimized Logical Plan == - * Project [scalar-subquery#242 [] AS scalarsubquery()#253, - * scalar-subquery#243 [] AS scalarsubquery()#254L] - * : :- Aggregate [avg(a#244) AS avg(a)#247] - * : : +- Project [a#244] - * : : +- Relation default.t[a#244,b#245] parquet - * : +- Aggregate [sum(a#251) AS sum(a)#250L] - * : +- Project [a#251] - * : +- Relation default.t[a#251,b#252] parquet - * +- OneRowRelation - * - * to: - * - * == Optimized Logical Plan == - * Project [scalar-subquery#242 [].avg(a) AS scalarsubquery()#253, - * scalar-subquery#243 [].sum(a) AS scalarsubquery()#254L] - * : :- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260] - * : : +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L] - * : : +- Project [a#244] - * : : +- Relation default.t[a#244,b#245] parquet - * : +- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260] - * : +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L] - * : +- Project [a#244] - * : +- Relation default.t[a#244,b#245] parquet - * +- OneRowRelation - * - * == Physical Plan == - * *(1) Project [Subquery scalar-subquery#242, [id=#125].avg(a) AS scalarsubquery()#253, - * ReusedSubquery - * Subquery scalar-subquery#242, [id=#125].sum(a) AS scalarsubquery()#254L] - * : :- Subquery scalar-subquery#242, [id=#125] - * : : +- *(2) Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260] - * : : +- *(2) HashAggregate(keys=[], functions=[avg(a#244), sum(a#244)], - * output=[avg(a)#247, sum(a)#250L]) - * : : +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#120] - * : : +- *(1) HashAggregate(keys=[], functions=[partial_avg(a#244), partial_sum(a#244)], - * output=[sum#262, count#263L, sum#264L]) - * : : +- *(1) ColumnarToRow - * : : +- FileScan parquet default.t[a#244] ... - * : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125] - * +- *(1) Scan OneRowRelation[] - */ -object MergeScalarSubqueries extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - plan match { - // Subquery reuse needs to be enabled for this optimization. - case _ if !conf.getConf(SQLConf.SUBQUERY_REUSE_ENABLED) => plan - - // This rule does a whole plan traversal, no need to run on subqueries. - case _: Subquery => plan - - // Plans with CTEs are not supported for now. - case _: WithCTE => plan - - case _ => extractCommonScalarSubqueries(plan) - } - } - - private def extractCommonScalarSubqueries(plan: LogicalPlan) = { - // Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert references in place of - // `ScalarSubquery`s. - val planMergers = ArrayBuffer.empty[PlanMerger] - val planWithReferences = insertReferences(plan, planMergers)._1 - - // Traverse level by level and convert merged plans to `CTERelationDef`s and keep non-merged - // ones. While traversing replace references in plans back to `CTERelationRef`s or to original - // `ScalarSubquery`s. This is safe as a subquery plan at a level can reference only lower level - // other subqueries. - val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]] - planMergers.foreach { planMerger => - val mergedPlans = planMerger.mergedPlans() - subqueryPlansByLevel += mergedPlans.map { mergedPlan => - val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) { - // Level 0 plans can't contain references - mergedPlan.plan - } else { - removeReferences(mergedPlan.plan, subqueryPlansByLevel) - } - if (mergedPlan.merged && mergedPlan.plan.output.size > 1) { - CTERelationDef( - Project( - Seq(Alias( - CreateNamedStruct( - planWithoutReferences.output.flatMap(a => Seq(Literal(a.name), a))), - "mergedValue")()), - planWithoutReferences), - underSubquery = true) - } else { - planWithoutReferences - } - } - } - - // Replace references back to `CTERelationRef`s or to original `ScalarSubquery`s in the main - // plan. - val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel) - - // Add `CTERelationDef`s to the plan. - val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte }) - if (subqueryCTEs.nonEmpty) { - WithCTE(newPlan, subqueryCTEs.toSeq) - } else { - newPlan - } - } - - // First traversal inserts `ScalarSubqueryReference`s to the plan and tries to merge subquery - // plans by each level. - private def insertReferences( - plan: LogicalPlan, - planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = { - // The level of a subquery plan is maximum level of its inner subqueries + 1 or 0 if it has no - // inner subqueries. - var maxLevel = 0 - val planWithReferences = - plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) { - case s: ScalarSubquery if !s.isCorrelated && s.deterministic => - val (planWithReferences, level) = insertReferences(s.plan, planMergers) - - while (level >= planMergers.size) planMergers += new PlanMerger() - // The subquery could contain a hint that is not propagated once we merge it, but as a - // non-correlated scalar subquery won't be turned into a Join the loss of hints is fine. - val mergeResult = planMergers(level).merge(planWithReferences) - - maxLevel = maxLevel.max(level + 1) - - val mergedOutput = mergeResult.outputMap(planWithReferences.output.head) - val outputIndex = - mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == mergedOutput.exprId) - ScalarSubqueryReference( - level, - mergeResult.mergedPlanIndex, - outputIndex, - s.dataType, - s.exprId) - case o => o - } - (planWithReferences, maxLevel) - } - - // Second traversal replaces `ScalarSubqueryReference`s to either - // `GetStructField(ScalarSubquery(CTERelationRef to the merged plan)` if the plan is merged from - // multiple subqueries or `ScalarSubquery(original plan)` if it isn't. - private def removeReferences( - plan: LogicalPlan, - subqueryPlansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = { - plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { - case ssr: ScalarSubqueryReference => - subqueryPlansByLevel(ssr.level)(ssr.mergedPlanIndex) match { - case cte: CTERelationDef => - GetStructField( - ScalarSubquery( - CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming), - exprId = ssr.exprId), - ssr.outputIndex) - case o => ScalarSubquery(o, exprId = ssr.exprId) - } - } - } -} - -/** - * Temporal reference to a subquery which is added to a `PlanMerger`. - * - * @param level The level of the replaced subquery. It defines the `PlanMerger` instance into which - * the subquery is merged. - * @param mergedPlanIndex The index of the merged plan in the `PlanMerger`. - * @param outputIndex The index of the output attribute of the merged plan. - * @param dataType The dataType of original scalar subquery. - * @param exprId The expression id of the original scalar subquery. - */ -case class ScalarSubqueryReference( - level: Int, - mergedPlanIndex: Int, - outputIndex: Int, - dataType: DataType, - exprId: ExprId) extends LeafExpression with Unevaluable { - override def nullable: Boolean = true - - final override val nodePatterns: Seq[TreePattern] = Seq(SCALAR_SUBQUERY_REFERENCE) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala new file mode 100644 index 000000000000..e132b1ad2180 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, LeafNode, LogicalPlan, OneRowRelation, Project, Subquery, WithCTE} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, NO_GROUPING_AGGREGATE_REFERENCE, SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * This rule tries to merge multiple subplans that have one row result. This can be either the plan + * tree of a [[ScalarSubquery]] expression or the plan tree starting at a non-grouping [[Aggregate]] + * node. + * + * The process is the following: + * - While traversing through the plan each one row returning subplan is tried to merge into already + * seen one row returning subplans using `PlanMerger`s. + * During this first traversal each [[ScalarSubquery]] expression is replaced to a temporal + * [[ScalarSubqueryReference]] and each non-grouping [[Aggregate]] node is replaced to a temporal + * [[NonGroupingAggregateReference]] pointing to its possible merged version in `PlanMerger`s. + * `PlanMerger`s keep track of whether a plan is a result of merging 2 or more subplans, or is an + * original unmerged plan. + * [[ScalarSubqueryReference]]s and [[NonGroupingAggregateReference]]s contain all the required + * information to either restore the original subplan or create a reference to a merged CTE. + * - Once the first traversal is complete and all possible merging have been done, a second + * traversal removes the references to either restore the original subplans or to replace the + * original to a modified ones that reference a CTE with a merged plan. + * A modified [[ScalarSubquery]] is constructed like: + * `GetStructField(ScalarSubquery(CTERelationRef to the merged plan), merged output index)` + * ans a modified [[Aggregate]] is constructed like: + * ``` + * Project( + * Seq( + * GetStructField( + * ScalarSubquery(CTERelationRef to the merged plan), + * merged output index 1), + * GetStructField( + * ScalarSubquery(CTERelationRef to the merged plan), + * merged output index 2), + * ...), + * OneRowRelation) + * ``` + * where `merged output index`s are the index of the output attributes (of the CTE) that + * correspond to the output of the original node. + * - If there are merged subqueries in `PlanMerger`s then a `WithCTE` node is built from these + * queries. The `CTERelationDef` nodes contain the merged subplans in the following form: + * `Project(Seq(CreateNamedStruct(name1, attribute1, ...) AS mergedValue), mergedSubplan)`. + * The definitions are flagged that they host a subplan, that can return maximum one row. + * + * Here are a few examples: + * + * 1. a query with 2 subqueries: + * ``` + * Project [scalar-subquery [] AS scalarsubquery(), scalar-subquery [] AS scalarsubquery()] + * : :- Aggregate [min(a) AS min(a)] + * : : +- Relation [a, b, c] + * : +- Aggregate [sum(b) AS sum(b)] + * : +- Relation [a, b, c] + * +- OneRowRelation + * ``` + * is optimized to: + * ``` + * WithCTE + * :- CTERelationDef 0 + * : +- Project [named_struct(min(a), min(a), sum(b), sum(b)) AS mergedValue] + * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b)] + * : +- Relation [a, b, c] + * +- Project [scalar-subquery [].min(a) AS scalarsubquery(), + * scalar-subquery [].sum(b) AS scalarsubquery()] + * : :- CTERelationRef 0 + * : +- CTERelationRef 0 + * +- OneRowRelation + * ``` + * + * 2. a query with 2 non-grouping aggregates: + * ``` + * Join Inner + * :- Aggregate [min(a) AS min(a)] + * : +- Relation [a, b, c] + * +- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)] + * +- Relation [a, b, c] + * ``` + * is optimized to: + * ``` + * WithCTE + * :- CTERelationDef 0 + * : +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue] + * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)] + * : +- Relation [a, b, c] + * +- Join Inner + * :- Project [scalar-subquery [].min(a) AS min(a)] + * : : +- CTERelationRef 0 + * : +- OneRowRelation + * +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)] + * : :- CTERelationRef 0 + * : +- CTERelationRef 0 + * +- OneRowRelation + * ``` + * + * 3. a query with a subquery and a non-grouping aggregate: + * ``` + * Join Inner + * :- Project [scalar-subquery [] AS scalarsubquery()] + * : : +- Aggregate [min(a) AS min(a)] + * : : +- Relation [a, b, c] + * : +- OneRowRelation + * +- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)] + * +- Relation [a, b, c] + * ``` + * is optimized to: + * ``` + * WithCTE + * :- CTERelationDef 0 + * : +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue] + * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)] + * : +- Relation [a, b, c] + * +- Join Inner + * :- Project [scalar-subquery [].min(a) AS scalarsubquery()] + * : : +- CTERelationRef 0 + * : +- OneRowRelation + * +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)] + * : :- CTERelationRef 0 + * : +- CTERelationRef 0 + * +- OneRowRelation + * ``` + */ +object MergeSubplans extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan match { + // Subquery reuse needs to be enabled for this optimization. + case _ if !conf.getConf(SQLConf.SUBQUERY_REUSE_ENABLED) => plan + + // This rule does a whole plan traversal, no need to run on subqueries. + case _: Subquery => plan + + // Plans with CTEs are not supported for now. + case _: WithCTE => plan + + case _ => extractCommonScalarSubqueries(plan) + } + } + + private def extractCommonScalarSubqueries(plan: LogicalPlan) = { + // Collect subplans by level into `PlanMerger`s and insert references in place of them. + val planMergers = ArrayBuffer.empty[PlanMerger] + val planWithReferences = insertReferences(plan, true, planMergers)._1 + + // Traverse level by level and convert merged plans to `CTERelationDef`s and keep non-merged + // ones. While traversing replace references in plans back to `CTERelationRef`s or to original + // plans. This is safe as a subplan at a level can reference only lower level ot other subplans. + val subplansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]] + planMergers.foreach { planMerger => + val mergedPlans = planMerger.mergedPlans() + subplansByLevel += mergedPlans.map { mergedPlan => + val planWithoutReferences = if (subplansByLevel.isEmpty) { + // Level 0 plans can't contain references + mergedPlan.plan + } else { + removeReferences(mergedPlan.plan, subplansByLevel) + } + if (mergedPlan.merged) { + CTERelationDef( + Project( + Seq(Alias( + CreateNamedStruct( + planWithoutReferences.output.flatMap(a => Seq(Literal(a.name), a))), + "mergedValue")()), + planWithoutReferences), + underSubquery = true) + } else { + planWithoutReferences + } + } + } + + // Replace references back to `CTERelationRef`s or to original subplans. + val newPlan = removeReferences(planWithReferences, subplansByLevel) + + // Add `CTERelationDef`s to the plan. + val subplanCTEs = subplansByLevel.flatMap(_.collect { case cte: CTERelationDef => cte }) + if (subplanCTEs.nonEmpty) { + WithCTE(newPlan, subplanCTEs.toSeq) + } else { + newPlan + } + } + + // First traversal inserts `ScalarSubqueryReference`s and `NoGroupingAggregateReference`s to the + // plan and tries to merge subplans by each level. Levels are separated eiter by scalar subqueries + // or by non-grouping aggregate nodes. Nodes with the same level make sense to try merging. + private def insertReferences( + plan: LogicalPlan, + root: Boolean, + planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = { + if (!plan.containsAnyPattern(AGGREGATE, SCALAR_SUBQUERY)) { + return (plan, 0) + } + + // Calculate the level propagated from subquery plans, which is the maximum level of the + // subqueries of the node + 1 or 0 if the node has no subqueries. + var levelFromSubqueries = 0 + val nodeSubqueriesWithReferences = + plan.transformExpressionsWithPruning(_.containsPattern(SCALAR_SUBQUERY)) { + case s: ScalarSubquery if !s.isCorrelated && s.deterministic => + val (planWithReferences, level) = insertReferences(s.plan, true, planMergers) + + // The subquery could contain a hint that is not propagated once we merge it, but as a + // non-correlated scalar subquery won't be turned into a Join the loss of hints is fine. + val mergeResult = getPlanMerger(planMergers, level).merge(planWithReferences, true) + + levelFromSubqueries = levelFromSubqueries.max(level + 1) + + val mergedOutput = mergeResult.outputMap(planWithReferences.output.head) + val outputIndex = + mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == mergedOutput.exprId) + ScalarSubqueryReference( + level, + mergeResult.mergedPlanIndex, + outputIndex, + s.dataType, + s.exprId) + case o => o + } + + // Calculate the level of the node, which is the maximum of the above calculated level + // propagated from subqueries and the level propagated from child nodes. + val (planWithReferences, level) = nodeSubqueriesWithReferences match { + case a: Aggregate if !root && a.groupingExpressions.isEmpty => + val (childWithReferences, levelFromChild) = insertReferences(a.child, false, planMergers) + val aggregateWithReferences = a.withNewChildren(Seq(childWithReferences)) + + // Level is the maximum of the level from subqueries and the level from child. + val level = levelFromChild.max(levelFromSubqueries) + + val mergeResult = getPlanMerger(planMergers, level).merge(aggregateWithReferences, false) + + val mergedOutput = aggregateWithReferences.output.map(mergeResult.outputMap) + val outputIndices = + mergedOutput.map(a => mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == a.exprId)) + val aggregateReference = NonGroupingAggregateReference( + level, + mergeResult.mergedPlanIndex, + outputIndices, + a.output + ) + + // This is a non-grouping aggregate node so propagate the level of the node + 1 to its + // parent + (aggregateReference, level + 1) + case o => + val (newChildren, levels) = o.children.map(insertReferences(_, false, planMergers)).unzip + // Level is the maximum of the level from subqueries and the level from the children. + (o.withNewChildren(newChildren), (levelFromSubqueries +: levels).max) + } + + (planWithReferences, level) + } + + private def getPlanMerger(planMergers: ArrayBuffer[PlanMerger], level: Int) = { + while (level >= planMergers.size) planMergers += new PlanMerger() + planMergers(level) + } + + // Second traversal replaces: + // - a `ScalarSubqueryReference` either to + // `GetStructField(ScalarSubquery(CTERelationRef to the merged plan), merged output index)` if + // the plan is merged from multiple subqueries or to `ScalarSubquery(original plan)` if it + // isn't. + // - a `NoGroupingAggregateReference` either to + // ``` + // Project( + // Seq( + // GetStructField( + // ScalarSubquery(CTERelationRef to the merged plan), + // merged output index 1), + // GetStructField( + // ScalarSubquery(CTERelationRef to the merged plan), + // merged output index 2), + // ...), + // OneRowRelation) + // ``` + // if the plan is merged from multiple subqueries or to `original plan` if it isn't. + private def removeReferences( + plan: LogicalPlan, + subplansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = { + plan.transformUpWithPruning( + _.containsAnyPattern(NO_GROUPING_AGGREGATE_REFERENCE, SCALAR_SUBQUERY_REFERENCE)) { + case ngar: NonGroupingAggregateReference => + subplansByLevel(ngar.level)(ngar.mergedPlanIndex) match { + case cte: CTERelationDef => + val projectList = ngar.outputIndices.zip(ngar.output).map { case (i, a) => + Alias( + GetStructField( + ScalarSubquery( + CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming)), + i), + a.name)(a.exprId) + } + Project(projectList, OneRowRelation()) + case o => o + } + case o => o.transformExpressionsUpWithPruning(_.containsPattern(SCALAR_SUBQUERY_REFERENCE)) { + case ssr: ScalarSubqueryReference => + subplansByLevel(ssr.level)(ssr.mergedPlanIndex) match { + case cte: CTERelationDef => + GetStructField( + ScalarSubquery( + CTERelationRef(cte.id, _resolved = true, cte.output, cte.isStreaming), + exprId = ssr.exprId), + ssr.outputIndex) + case o => ScalarSubquery(o, exprId = ssr.exprId) + } + } + } + } +} + +/** + * Temporal reference to a subquery which is added to a `PlanMerger`. + * + * @param level The level of the replaced subquery. It defines the `PlanMerger` instance into which + * the subquery is merged. + * @param mergedPlanIndex The index of the merged plan in the `PlanMerger`. + * @param outputIndex The index of the output attribute of the merged plan. + * @param dataType The dataType of original scalar subquery. + * @param exprId The expression id of the original scalar subquery. + */ +case class ScalarSubqueryReference( + level: Int, + mergedPlanIndex: Int, + outputIndex: Int, + override val dataType: DataType, + exprId: ExprId) extends LeafExpression with Unevaluable { + override def nullable: Boolean = true + + final override val nodePatterns: Seq[TreePattern] = Seq(SCALAR_SUBQUERY_REFERENCE) +} + +/** + * Temporal reference to a non-grouping aggregate which is added to a `PlanMerger`. + * + * @param level The level of the replaced aggregate. It defines the `PlanMerger` instance into which + * the aggregate is merged. + * @param mergedPlanIndex The index of the merged plan in the `PlanMerger`. + * @param outputIndices The indices of the output attributes of the merged plan. + * @param output The output of original aggregate. + */ +case class NonGroupingAggregateReference( + level: Int, + mergedPlanIndex: Int, + outputIndices: Seq[Int], + override val output: Seq[Attribute]) extends LeafNode { + final override val nodePatterns: Seq[TreePattern] = Seq(NO_GROUPING_AGGREGATE_REFERENCE) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala index 37982d163927..1623166e0a65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -58,8 +58,8 @@ case class MergedPlan(plan: LogicalPlan, merged: Boolean) * 2. Merge a new plan with a cached plan by combining their outputs * * The merging process preserves semantic equivalence while combining outputs from multiple - * plans into a single plan. This is primarily used by [[MergeScalarSubqueries]] to deduplicate - * scalar subquery execution. + * plans into a single plan. This is primarily used by [[MergeSubplans]] to deduplicate subplan + * execution. * * Supported plan types for merging: * - [[Project]]: Merges project lists @@ -88,16 +88,21 @@ class PlanMerger { * 3. If no merge is possible, add as a new cache entry * * @param plan The logical plan to merge or cache. + * @param subqueryPlan If the logical plan is a subquery plan. * @return A [[MergeResult]] containing: * - The merged/cached plan to use * - Its index in the cache * - An attribute mapping for rewriting expressions */ - def merge(plan: LogicalPlan): MergeResult = { + def merge(plan: LogicalPlan, subqueryPlan: Boolean): MergeResult = { cache.zipWithIndex.collectFirst(Function.unlift { case (mp, i) => checkIdenticalPlans(plan, mp.plan).map { outputMap => - val newMergePlan = MergedPlan(mp.plan, true) + // Identical subquery expression plans are not marked as `merged` as the + // `ReusedSubqueryExec` rule can handle them without extracting the plans to CTEs. + // But, when a non-subquery subplan is identical to a cached plan we need to mark the plan + // `merged` and so extract it to a CTE later. + val newMergePlan = MergedPlan(mp.plan, cache(i).merged || !subqueryPlan) cache(i) = newMergePlan MergeResult(newMergePlan, i, outputMap) }.orElse { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index ba4e801ed0a6..5ea93e74c5d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -150,6 +150,7 @@ object TreePattern extends Enumeration { val LOCAL_RELATION: Value = Value val LOGICAL_QUERY_STAGE: Value = Value val NATURAL_LIKE_JOIN: Value = Value + val NO_GROUPING_AGGREGATE_REFERENCE: Value = Value val OFFSET: Value = Value val OUTER_JOIN: Value = Value val PARAMETERIZED_QUERY: Value = Value diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala similarity index 82% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala index 008b4a89ce60..b368035e278e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala @@ -25,14 +25,14 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -class MergeScalarSubqueriesSuite extends PlanTest { +class MergeSubplansSuite extends PlanTest { override def beforeEach(): Unit = { CTERelationDef.curId.set(0) } private object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("MergeScalarSubqueries", Once, MergeScalarSubqueries) :: Nil + val batches = Batch("MergeSubplans", Once, MergeSubplans) :: Nil } val testRelation = LocalRelation($"a".int, $"b".int, $"c".string) @@ -590,4 +590,135 @@ class MergeScalarSubqueriesSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) } + + test("Merge aggregates") { + val agg1 = testRelation.groupBy()(min($"a").as("min_a")) + val agg2 = testRelation.groupBy()(max($"a").as("max_a")) + val originalQuery = agg1.join(agg2) + + val mergedSubquery = testRelation + .groupBy()( + min($"a").as("min_a"), + max($"a").as("max_a") + ) + .select( + CreateNamedStruct(Seq( + Literal("min_a"), $"min_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + OneRowRelation().select(extractorExpression(0, analyzedMergedSubquery.output, 0, "min_a")) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 1, "max_a"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Merge non-siblig aggregates") { + val agg1 = testRelation.groupBy()(min($"a").as("min_a")) + val agg2 = testRelation.groupBy()(max($"a").as("max_a")) + val originalQuery = agg1.join(testRelation).join(agg2) + + val mergedSubquery = testRelation + .groupBy()( + min($"a").as("min_a"), + max($"a").as("max_a") + ) + .select( + CreateNamedStruct(Seq( + Literal("min_a"), $"min_a", + Literal("max_a"), $"max_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + OneRowRelation().select(extractorExpression(0, analyzedMergedSubquery.output, 0, "min_a")) + .join(testRelation) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 1, "max_a"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Merge subqueries and aggregates") { + val subquery1 = ScalarSubquery(testRelation.groupBy()(min($"a").as("min_a"))) + val subquery2 = ScalarSubquery(testRelation.groupBy()(max($"a").as("max_a"))) + val agg1 = testRelation.groupBy()(sum($"a").as("sum_a")) + val agg2 = testRelation.groupBy()(avg($"a").as("avg_a")) + val originalQuery = + testRelation + .select( + subquery1, + subquery2) + .join(agg1) + .join(agg2) + + val mergedSubquery = testRelation + .groupBy()( + min($"a").as("min_a"), + max($"a").as("max_a"), + sum($"a").as("sum_a"), + avg($"a").as("avg_a") + ) + .select( + CreateNamedStruct(Seq( + Literal("min_a"), $"min_a", + Literal("max_a"), $"max_a", + Literal("sum_a"), $"sum_a", + Literal("avg_a"), $"avg_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation + .select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 2, "sum_a"))) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 3, "avg_a"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Merge identical subqueries and aggregates") { + val subquery1 = ScalarSubquery(testRelation.groupBy()(min($"a").as("min_a"))) + val subquery2 = ScalarSubquery(testRelation.groupBy()(min($"a").as("min_a_2"))) + val agg1 = testRelation.groupBy()(min($"a").as("min_a_3")) + val agg2 = testRelation.groupBy()(min($"a").as("min_a_4")) + val originalQuery = + testRelation + .select( + subquery1, + subquery2) + .join(agg1) + .join(agg2) + + val mergedSubquery = testRelation + .groupBy()(min($"a").as("min_a")) + .select( + CreateNamedStruct(Seq(Literal("min_a"), $"min_a")).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation + .select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 0)) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 0, "min_a_3"))) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 0, "min_a_4"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 8edb59f49282..7f3b8383f0f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -61,8 +61,8 @@ class SparkOptimizer( new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)), Batch("InjectRuntimeFilter", FixedPoint(1), InjectRuntimeFilter), - Batch("MergeScalarSubqueries", Once, - MergeScalarSubqueries, + Batch("MergeSubplans", Once, + MergeSubplans, RewriteDistinctAggregates), Batch("Pushdown Filters from PartitionPruning", fixedPoint, PushDownPredicates), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala index 7d7185ae6c13..603ec183bfb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate} -import org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries +import org.apache.spark.sql.catalyst.optimizer.MergeSubplans import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan} import org.apache.spark.sql.execution.{ReusedSubqueryExec, SubqueryExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation} @@ -207,7 +207,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp // `MergeScalarSubqueries` can duplicate subqueries in the optimized plan and would make testing // complicated. - conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, MergeScalarSubqueries.ruleName) + conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, MergeSubplans.ruleName) } protected override def afterAll(): Unit = try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 3ba48da0e327..2ce67b5bed3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2389,6 +2389,104 @@ class SubquerySuite extends QueryTest } } + test("Merge non-grouping aggregates") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT * + |FROM (SELECT avg(key) FROM testData) + |JOIN (SELECT sum(key) FROM testData) + |JOIN (SELECT count(distinct key) FROM testData) + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-grouping aggregates from different levels") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | first(avg_key), + | ( + | -- Using `testData2` makes the whole subquery plan non-mergeable to the + | -- non-grouping aggregate subplan in the main plan, which uses `testData`, but its + | -- aggregate subplan with `sum(key)` is mergeable + | SELECT first(sum_key) + | FROM (SELECT sum(key) AS sum_key FROM testData) + | JOIN testData2 + | ), + | first(count_key) + |FROM (SELECT avg(key) AS avg_key, count(distinct key) as count_key FROM testData) + |JOIN testData3 + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-grouping aggregate and subquery") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | first(avg_key), + | ( + | -- In this case the whole scalar subquery plan is mergeable to the non-grouping + | -- aggregate subplan in the main plan. + | SELECT sum(key) AS sum_key FROM testData + | ), + | first(count_key) + |FROM (SELECT avg(key) AS avg_key, count(distinct key) as count_key FROM testData) + |JOIN testData3 + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + test("SPARK-39355: Single column uses quoted to construct UnresolvedAttribute") { checkAnswer( sql(""" From dafbd64f2b00d009c13cd7bf7e5b587d735caf58 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 13 Nov 2025 13:20:52 +0100 Subject: [PATCH 2/3] move SQL tests to separate `PlanMergeSuite` --- .../org/apache/spark/sql/PlanMergeSuite.scala | 342 ++++++++++++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 311 ---------------- 2 files changed, 342 insertions(+), 311 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala new file mode 100644 index 000000000000..b7557b42702e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class PlanMergeSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + import testImplicits._ + + setupTestData() + + test("Merge non-correlated scalar subqueries") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | (SELECT avg(key) FROM testData), + | (SELECT sum(key) FROM testData), + | (SELECT count(distinct key) FROM testData) + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-correlated scalar subqueries in a subquery") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT ( + | SELECT + | SUM( + | (SELECT avg(key) FROM testData) + + | (SELECT sum(key) FROM testData) + + | (SELECT count(distinct key) FROM testData)) + | FROM testData + |) + """.stripMargin) + + checkAnswer(df, Row(520050.0) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 5, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-correlated scalar subqueries from different levels") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | (SELECT avg(key) FROM testData), + | ( + | SELECT + | SUM( + | (SELECT sum(key) FROM testData) + | ) + | FROM testData + | ) + """.stripMargin) + + checkAnswer(df, Row(50.5, 505000) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-correlated scalar subqueries from different parent plans") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | ( + | SELECT + | SUM( + | (SELECT avg(key) FROM testData) + | ) + | FROM testData + | ), + | ( + | SELECT + | SUM( + | (SELECT sum(key) FROM testData) + | ) + | FROM testData + | ) + """.stripMargin) + + checkAnswer(df, Row(5050.0, 505000) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 4, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-correlated scalar subqueries with conflicting names") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | (SELECT avg(key) AS key FROM testData), + | (SELECT sum(key) AS key FROM testData), + | (SELECT count(distinct key) AS key FROM testData) + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-grouping aggregates") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT * + |FROM (SELECT avg(key) FROM testData) + |JOIN (SELECT sum(key) FROM testData) + |JOIN (SELECT count(distinct key) FROM testData) + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-grouping aggregates from different levels") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | first(avg_key), + | ( + | -- Using `testData2` makes the whole subquery plan non-mergeable to the + | -- non-grouping aggregate subplan in the main plan, which uses `testData`, but its + | -- aggregate subplan with `sum(key)` is mergeable + | SELECT first(sum_key) + | FROM (SELECT sum(key) AS sum_key FROM testData) + | JOIN testData2 + | ), + | first(count_key) + |FROM (SELECT avg(key) AS avg_key, count(distinct key) as count_key FROM testData) + |JOIN testData3 + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("Merge non-grouping aggregate and subquery") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { + val df = sql( + """ + |SELECT + | first(avg_key), + | ( + | -- In this case the whole scalar subquery plan is mergeable to the non-grouping + | -- aggregate subplan in the main plan. + | SELECT sum(key) AS sum_key FROM testData + | ), + | first(count_key) + |FROM (SELECT avg(key) AS avg_key, count(distinct key) as count_key FROM testData) + |JOIN testData3 + """.stripMargin) + + checkAnswer(df, Row(50.5, 5050, 100) :: Nil) + + val plan = df.queryExecution.executedPlan + + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 2, + "Missing or unexpected reused ReusedSubqueryExec in the plan") + } + } + } + + test("SPARK-40618: Regression test for merging subquery bug with nested subqueries") { + // This test contains a subquery expression with another subquery expression nested inside. + // It acts as a regression test to ensure that the MergeScalarSubqueries rule does not attempt + // to merge them together. + withTable("t1", "t2") { + sql("create table t1(col int) using csv") + checkAnswer(sql("select(select sum((select sum(col) from t1)) from t1)"), Row(null)) + + checkAnswer(sql( + """ + |select + | (select sum( + | (select sum( + | (select sum(col) from t1)) + | from t1)) + | from t1) + |""".stripMargin), + Row(null)) + + sql("create table t2(col int) using csv") + checkAnswer(sql( + """ + |select + | (select sum( + | (select sum( + | (select sum(col) from t1)) + | from t2)) + | from t1) + |""".stripMargin), + Row(null)) + } + } + + test("SPARK-42346: Rewrite distinct aggregates after merging subqueries") { + withTempView("t1") { + Seq((1, 2), (3, 4)).toDF("c1", "c2").createOrReplaceTempView("t1") + + checkAnswer(sql( + """ + |SELECT + | (SELECT count(distinct c1) FROM t1), + | (SELECT count(distinct c2) FROM t1) + |""".stripMargin), + Row(2, 2)) + + // In this case we don't merge the subqueries as `RewriteDistinctAggregates` kicks off for the + // 2 subqueries first but `MergeScalarSubqueries` is not prepared for the `Expand` nodes that + // are inserted by the rewrite. + checkAnswer(sql( + """ + |SELECT + | (SELECT count(distinct c1) + sum(distinct c2) FROM t1), + | (SELECT count(distinct c2) + sum(distinct c1) FROM t1) + |""".stripMargin), + Row(8, 6)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 2ce67b5bed3d..b53610761d04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2234,259 +2234,6 @@ class SubquerySuite extends QueryTest } } - test("Merge non-correlated scalar subqueries") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | (SELECT avg(key) FROM testData), - | (SELECT sum(key) FROM testData), - | (SELECT count(distinct key) FROM testData) - """.stripMargin) - - checkAnswer(df, Row(50.5, 5050, 100) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-correlated scalar subqueries in a subquery") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT ( - | SELECT - | SUM( - | (SELECT avg(key) FROM testData) + - | (SELECT sum(key) FROM testData) + - | (SELECT count(distinct key) FROM testData)) - | FROM testData - |) - """.stripMargin) - - checkAnswer(df, Row(520050.0) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 5, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-correlated scalar subqueries from different levels") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | (SELECT avg(key) FROM testData), - | ( - | SELECT - | SUM( - | (SELECT sum(key) FROM testData) - | ) - | FROM testData - | ) - """.stripMargin) - - checkAnswer(df, Row(50.5, 505000) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-correlated scalar subqueries from different parent plans") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | ( - | SELECT - | SUM( - | (SELECT avg(key) FROM testData) - | ) - | FROM testData - | ), - | ( - | SELECT - | SUM( - | (SELECT sum(key) FROM testData) - | ) - | FROM testData - | ) - """.stripMargin) - - checkAnswer(df, Row(5050.0, 505000) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 4, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-correlated scalar subqueries with conflicting names") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | (SELECT avg(key) AS key FROM testData), - | (SELECT sum(key) AS key FROM testData), - | (SELECT count(distinct key) AS key FROM testData) - """.stripMargin) - - checkAnswer(df, Row(50.5, 5050, 100) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-grouping aggregates") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT * - |FROM (SELECT avg(key) FROM testData) - |JOIN (SELECT sum(key) FROM testData) - |JOIN (SELECT count(distinct key) FROM testData) - """.stripMargin) - - checkAnswer(df, Row(50.5, 5050, 100) :: Nil) - - val plan = df.queryExecution.executedPlan - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-grouping aggregates from different levels") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | first(avg_key), - | ( - | -- Using `testData2` makes the whole subquery plan non-mergeable to the - | -- non-grouping aggregate subplan in the main plan, which uses `testData`, but its - | -- aggregate subplan with `sum(key)` is mergeable - | SELECT first(sum_key) - | FROM (SELECT sum(key) AS sum_key FROM testData) - | JOIN testData2 - | ), - | first(count_key) - |FROM (SELECT avg(key) AS avg_key, count(distinct key) as count_key FROM testData) - |JOIN testData3 - """.stripMargin) - - checkAnswer(df, Row(50.5, 5050, 100) :: Nil) - - val plan = df.queryExecution.executedPlan - - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - - test("Merge non-grouping aggregate and subquery") { - Seq(false, true).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { - val df = sql( - """ - |SELECT - | first(avg_key), - | ( - | -- In this case the whole scalar subquery plan is mergeable to the non-grouping - | -- aggregate subplan in the main plan. - | SELECT sum(key) AS sum_key FROM testData - | ), - | first(count_key) - |FROM (SELECT avg(key) AS avg_key, count(distinct key) as count_key FROM testData) - |JOIN testData3 - """.stripMargin) - - checkAnswer(df, Row(50.5, 5050, 100) :: Nil) - - val plan = df.queryExecution.executedPlan - - val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } - val reusedSubqueryIds = collectWithSubqueries(plan) { - case rs: ReusedSubqueryExec => rs.child.id - } - - assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } - } - } - test("SPARK-39355: Single column uses quoted to construct UnresolvedAttribute") { checkAnswer( sql(""" @@ -2587,39 +2334,6 @@ class SubquerySuite extends QueryTest } } - test("SPARK-40618: Regression test for merging subquery bug with nested subqueries") { - // This test contains a subquery expression with another subquery expression nested inside. - // It acts as a regression test to ensure that the MergeScalarSubqueries rule does not attempt - // to merge them together. - withTable("t1", "t2") { - sql("create table t1(col int) using csv") - checkAnswer(sql("select(select sum((select sum(col) from t1)) from t1)"), Row(null)) - - checkAnswer(sql( - """ - |select - | (select sum( - | (select sum( - | (select sum(col) from t1)) - | from t1)) - | from t1) - |""".stripMargin), - Row(null)) - - sql("create table t2(col int) using csv") - checkAnswer(sql( - """ - |select - | (select sum( - | (select sum( - | (select sum(col) from t1)) - | from t2)) - | from t1) - |""".stripMargin), - Row(null)) - } - } - test("SPARK-40615: Check unsupported data type when decorrelating subqueries") { withTempView("v1", "v2") { sql( @@ -2714,31 +2428,6 @@ class SubquerySuite extends QueryTest } } - test("SPARK-42346: Rewrite distinct aggregates after merging subqueries") { - withTempView("t1") { - Seq((1, 2), (3, 4)).toDF("c1", "c2").createOrReplaceTempView("t1") - - checkAnswer(sql( - """ - |SELECT - | (SELECT count(distinct c1) FROM t1), - | (SELECT count(distinct c2) FROM t1) - |""".stripMargin), - Row(2, 2)) - - // In this case we don't merge the subqueries as `RewriteDistinctAggregates` kicks off for the - // 2 subqueries first but `MergeScalarSubqueries` is not prepared for the `Expand` nodes that - // are inserted by the rewrite. - checkAnswer(sql( - """ - |SELECT - | (SELECT count(distinct c1) + sum(distinct c2) FROM t1), - | (SELECT count(distinct c2) + sum(distinct c1) FROM t1) - |""".stripMargin), - Row(8, 6)) - } - } - test("SPARK-42745: Improved AliasAwareOutputExpression works with DSv2") { withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> "") { From 335b59352c67fcfa94500218f1b36107c26ef0f6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 13 Nov 2025 13:30:50 +0100 Subject: [PATCH 3/3] revert `MergeSubplans` docs change to simlify review --- .../catalyst/optimizer/MergeSubplans.scala | 167 +++++++----------- 1 file changed, 62 insertions(+), 105 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala index e132b1ad2180..5ba64360ffc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala @@ -27,121 +27,78 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType /** - * This rule tries to merge multiple subplans that have one row result. This can be either the plan - * tree of a [[ScalarSubquery]] expression or the plan tree starting at a non-grouping [[Aggregate]] - * node. + * This rule tries to merge multiple non-correlated [[ScalarSubquery]]s to compute multiple scalar + * values once. * * The process is the following: - * - While traversing through the plan each one row returning subplan is tried to merge into already - * seen one row returning subplans using `PlanMerger`s. + * - While traversing through the plan each [[ScalarSubquery]] plan is tried to merge into already + * seen subquery plans using `PlanMerger`s. * During this first traversal each [[ScalarSubquery]] expression is replaced to a temporal - * [[ScalarSubqueryReference]] and each non-grouping [[Aggregate]] node is replaced to a temporal - * [[NonGroupingAggregateReference]] pointing to its possible merged version in `PlanMerger`s. - * `PlanMerger`s keep track of whether a plan is a result of merging 2 or more subplans, or is an - * original unmerged plan. - * [[ScalarSubqueryReference]]s and [[NonGroupingAggregateReference]]s contain all the required - * information to either restore the original subplan or create a reference to a merged CTE. - * - Once the first traversal is complete and all possible merging have been done, a second - * traversal removes the references to either restore the original subplans or to replace the - * original to a modified ones that reference a CTE with a merged plan. + * [[ScalarSubqueryReference]] pointing to its possible merged version stored in `PlanMerger`s. + * `PlanMerger`s keep track of whether a plan is a result of merging 2 or more plans, or is an + * original unmerged plan. [[ScalarSubqueryReference]]s contain all the required information to + * either restore the original [[ScalarSubquery]] or create a reference to a merged CTE. + * - Once the first traversal is complete and all possible merging have been done a second traversal + * removes the [[ScalarSubqueryReference]]s to either restore the original [[ScalarSubquery]] or + * to replace the original to a modified one that references a CTE with a merged plan. * A modified [[ScalarSubquery]] is constructed like: - * `GetStructField(ScalarSubquery(CTERelationRef to the merged plan), merged output index)` - * ans a modified [[Aggregate]] is constructed like: - * ``` - * Project( - * Seq( - * GetStructField( - * ScalarSubquery(CTERelationRef to the merged plan), - * merged output index 1), - * GetStructField( - * ScalarSubquery(CTERelationRef to the merged plan), - * merged output index 2), - * ...), - * OneRowRelation) - * ``` - * where `merged output index`s are the index of the output attributes (of the CTE) that - * correspond to the output of the original node. + * `GetStructField(ScalarSubquery(CTERelationRef(...)), outputIndex)` where `outputIndex` is the + * index of the output attribute (of the CTE) that corresponds to the output of the original + * subquery. * - If there are merged subqueries in `PlanMerger`s then a `WithCTE` node is built from these - * queries. The `CTERelationDef` nodes contain the merged subplans in the following form: - * `Project(Seq(CreateNamedStruct(name1, attribute1, ...) AS mergedValue), mergedSubplan)`. - * The definitions are flagged that they host a subplan, that can return maximum one row. + * queries. The `CTERelationDef` nodes contain the merged subquery in the following form: + * `Project(Seq(CreateNamedStruct(name1, attribute1, ...) AS mergedValue), mergedSubqueryPlan)`. + * The definitions are flagged that they host a subquery, that can return maximum one row. * - * Here are a few examples: + * Eg. the following query: * - * 1. a query with 2 subqueries: - * ``` - * Project [scalar-subquery [] AS scalarsubquery(), scalar-subquery [] AS scalarsubquery()] - * : :- Aggregate [min(a) AS min(a)] - * : : +- Relation [a, b, c] - * : +- Aggregate [sum(b) AS sum(b)] - * : +- Relation [a, b, c] + * SELECT + * (SELECT avg(a) FROM t), + * (SELECT sum(b) FROM t) + * + * is optimized from: + * + * == Optimized Logical Plan == + * Project [scalar-subquery#242 [] AS scalarsubquery()#253, + * scalar-subquery#243 [] AS scalarsubquery()#254L] + * : :- Aggregate [avg(a#244) AS avg(a)#247] + * : : +- Project [a#244] + * : : +- Relation default.t[a#244,b#245] parquet + * : +- Aggregate [sum(a#251) AS sum(a)#250L] + * : +- Project [a#251] + * : +- Relation default.t[a#251,b#252] parquet * +- OneRowRelation - * ``` - * is optimized to: - * ``` - * WithCTE - * :- CTERelationDef 0 - * : +- Project [named_struct(min(a), min(a), sum(b), sum(b)) AS mergedValue] - * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b)] - * : +- Relation [a, b, c] - * +- Project [scalar-subquery [].min(a) AS scalarsubquery(), - * scalar-subquery [].sum(b) AS scalarsubquery()] - * : :- CTERelationRef 0 - * : +- CTERelationRef 0 - * +- OneRowRelation - * ``` * - * 2. a query with 2 non-grouping aggregates: - * ``` - * Join Inner - * :- Aggregate [min(a) AS min(a)] - * : +- Relation [a, b, c] - * +- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)] - * +- Relation [a, b, c] - * ``` - * is optimized to: - * ``` - * WithCTE - * :- CTERelationDef 0 - * : +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue] - * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)] - * : +- Relation [a, b, c] - * +- Join Inner - * :- Project [scalar-subquery [].min(a) AS min(a)] - * : : +- CTERelationRef 0 - * : +- OneRowRelation - * +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)] - * : :- CTERelationRef 0 - * : +- CTERelationRef 0 - * +- OneRowRelation - * ``` + * to: + * + * == Optimized Logical Plan == + * Project [scalar-subquery#242 [].avg(a) AS scalarsubquery()#253, + * scalar-subquery#243 [].sum(a) AS scalarsubquery()#254L] + * : :- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260] + * : : +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L] + * : : +- Project [a#244] + * : : +- Relation default.t[a#244,b#245] parquet + * : +- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260] + * : +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L] + * : +- Project [a#244] + * : +- Relation default.t[a#244,b#245] parquet + * +- OneRowRelation * - * 3. a query with a subquery and a non-grouping aggregate: - * ``` - * Join Inner - * :- Project [scalar-subquery [] AS scalarsubquery()] - * : : +- Aggregate [min(a) AS min(a)] - * : : +- Relation [a, b, c] - * : +- OneRowRelation - * +- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)] - * +- Relation [a, b, c] - * ``` - * is optimized to: - * ``` - * WithCTE - * :- CTERelationDef 0 - * : +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue] - * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)] - * : +- Relation [a, b, c] - * +- Join Inner - * :- Project [scalar-subquery [].min(a) AS scalarsubquery()] - * : : +- CTERelationRef 0 - * : +- OneRowRelation - * +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)] - * : :- CTERelationRef 0 - * : +- CTERelationRef 0 - * +- OneRowRelation - * ``` + * == Physical Plan == + * *(1) Project [Subquery scalar-subquery#242, [id=#125].avg(a) AS scalarsubquery()#253, + * ReusedSubquery + * Subquery scalar-subquery#242, [id=#125].sum(a) AS scalarsubquery()#254L] + * : :- Subquery scalar-subquery#242, [id=#125] + * : : +- *(2) Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260] + * : : +- *(2) HashAggregate(keys=[], functions=[avg(a#244), sum(a#244)], + * output=[avg(a)#247, sum(a)#250L]) + * : : +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#120] + * : : +- *(1) HashAggregate(keys=[], functions=[partial_avg(a#244), partial_sum(a#244)], + * output=[sum#262, count#263L, sum#264L]) + * : : +- *(1) ColumnarToRow + * : : +- FileScan parquet default.t[a#244] ... + * : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125] + * +- *(1) Scan OneRowRelation[] */ object MergeSubplans extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = {