diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala index 6c0bca0e1104f..6466912d79874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE} import org.apache.spark.sql.catalyst.rules.Rule @@ -52,16 +53,26 @@ object NormalizeCTEIds extends Rule[LogicalPlan] { private def canonicalizeCTE( plan: LogicalPlan, - defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = { - plan.transformDownWithSubqueries { - // For nested WithCTE, if defIndex didn't contain the cteId, - // means it's not current WithCTE's ref. - case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) => - ref.copy(cteId = defIdToNewId(ref.cteId)) - case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) => - unionLoop.copy(id = defIdToNewId(unionLoop.id)) - case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) => - unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId)) - } + defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = plan match { + // Nested WithCTEs are normalized separately by applyInternal. + case _: WithCTE => plan + case other => + val normalizedPlan = other match { + case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) => + ref.copy(cteId = defIdToNewId(ref.cteId)) + case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) => + unionLoop.copy(id = defIdToNewId(unionLoop.id)) + case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) => + unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId)) + case _ => + other + } + + normalizedPlan + .withNewChildren(normalizedPlan.children.map(canonicalizeCTE(_, defIdToNewId))) + .transformExpressionsDown { + case subqueryExpression: SubqueryExpression => + subqueryExpression.withNewPlan(canonicalizeCTE(subqueryExpression.plan, defIdToNewId)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 7562d5669cc2c..27cbbc16c0b3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -261,6 +261,45 @@ abstract class CTEInlineSuiteBase } } + test("SPARK-56921: plan normalization handles nested CTEs under union") { + withTempView("input", "common") { + Seq((1, 1, 10), (1, 2, 20), (2, 1, 30)) + .toDF("a", "b", "value") + .createOrReplaceTempView("input") + + sql( + s"""with cte_common as ( + | select a, b, sum(value) as value + | from input + | group by a, b + |) + |select * from cte_common + """.stripMargin).createOrReplaceTempView("common") + + val left = sql( + s"""with cte_a as ( + | select a, sum(value) as value + | from common + | group by a + |) + |select a as id, value from cte_a + """.stripMargin) + + val right = sql( + s"""with cte_b as ( + | select b, sum(value) as value + | from common + | group by b + |) + |select b as id, value from cte_b + """.stripMargin) + + checkAnswer( + left.union(right), + Row(1, 30) :: Row(2, 30) :: Row(1, 40) :: Row(2, 20) :: Nil) + } + } + test("SPARK-36447: invalid nested CTEs") { withTempView("t") { Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")