From 347aad37ec5a409fd3a5cc876fe670b2baaa791d Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Fri, 12 Apr 2024 16:39:49 -0700 Subject: [PATCH 1/8] fix aggregate bug --- .../spark/sql/catalyst/expressions/With.scala | 11 ++ .../optimizer/RewriteWithExpression.scala | 61 ++++--- .../spark/sql/catalyst/plans/QueryPlan.scala | 24 +++ .../RewriteWithExpressionSuite.scala | 168 +++++++++++++----- 4 files changed, 204 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala index 2745b663639f8..fa72a6fc121d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -88,6 +88,17 @@ object With { val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_)) With(replaced(commonExprRefs), commonExprDefs) } + + /** + * Used for testing to specify the exact common expression ID for each common expression. + */ + def create(commonExprs: (Expression, Long)*)(replaced: Seq[Expression] => Expression): With = { + val commonExprDefs = commonExprs.map { case (expr, exprId) => + CommonExpressionDef(expr, CommonExpressionId(exprId)) + } + val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_)) + With(replaced(commonExprRefs), commonExprDefs) + } } case class CommonExpressionId(id: Long = CommonExpressionId.newId, canonicalized: Boolean = false) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index 934eadbcee551..282b9254c8a98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -21,36 +21,57 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project} +import org.apache.spark.sql.catalyst.planning.PhysicalAggregation +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{ALIAS, COMMON_EXPR_REF, WITH_EXPRESSION} /** * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or * just inline them if they are cheap. * + * Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]] + * after this rule. + * * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its * usage, we should support aggregate/window functions as well. */ object RewriteWithExpression extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) { + plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) { + case p @ PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child, limit) + if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => + // For aggregates, separate computation of the aggregations themselves from the final + // result by moving the final result computation into a projection above. This prevents + // this rule from producing an invalid Aggregate operator. + // TODO: the names of these aliases will become outdated after the rewrite + val aggExprs = aggregateExpressions.map(ae => Alias(ae, ae.toString)(ae.resultIds.head)) + // Rewrite the projection and the aggregate separately and then piece them together. + val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child, limit) + val rewrittenAgg = applyInternal(agg) + val proj = Project(resultExpressions, rewrittenAgg) + applyInternal(proj) case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => - val inputPlans = p.children.toArray - var newPlan: LogicalPlan = p.mapExpressions { expr => - rewriteWithExprAndInputPlans(expr, inputPlans) - } - newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq) - // Since we add extra Projects with extra columns to pre-evaluate the common expressions, - // the current operator may have extra columns if it inherits the output columns from its - // child, and we need to project away the extra columns to keep the plan schema unchanged. - assert(p.output.length <= newPlan.output.length) - if (p.output.length < newPlan.output.length) { - assert(p.outputSet.subsetOf(newPlan.outputSet)) - Project(p.output, newPlan) - } else { - newPlan - } + applyInternal(p) + } + } + + private def applyInternal(p: LogicalPlan): LogicalPlan = { + val inputPlans = p.children.toArray + var newPlan: LogicalPlan = p.mapExpressions { expr => + rewriteWithExprAndInputPlans(expr, inputPlans) + } + newPlan = newPlan.withNewChildren(inputPlans) + // Since we add extra Projects with extra columns to pre-evaluate the common expressions, + // the current operator may have extra columns if it inherits the output columns from its + // child, and we need to project away the extra columns to keep the plan schema unchanged. + assert(p.output.length <= newPlan.output.length) + if (p.output.length < newPlan.output.length) { + assert(p.outputSet.subsetOf(newPlan.outputSet)) + Project(p.output, newPlan) + } else { + newPlan } } @@ -66,7 +87,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] { val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression] val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias]) - defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) => + defs.foreach { case CommonExpressionDef(child, id) => if (child.containsPattern(COMMON_EXPR_REF)) { throw SparkException.internalError( "Common expression definition cannot reference other Common expression definitions") @@ -93,7 +114,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] { // if it's ref count is 1. refToExpr(id) = child } else { - val alias = Alias(child, s"_common_expr_$index")() + val alias = Alias(child, s"_common_expr_${id.id}")() val fakeProj = Project(Seq(alias), inputPlans(childProjectionIndex)) if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) { // We have to inline the common expression if it cannot be put in a Project. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0f049103542ec..505330d871cbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -517,6 +517,30 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) } + /** + * Same as `transformUpWithSubqueries` except allows for pruning opportunities. + */ + def transformUpWithSubqueriesAndPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId) + (f: PartialFunction[PlanType, PlanType]): PlanType = { + val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { + override def isDefinedAt(x: PlanType): Boolean = true + + override def apply(plan: PlanType): PlanType = { + val transformed = plan.transformExpressionsUpWithPruning(t => + t.containsPattern(PLAN_EXPRESSION) && cond(t)) { + case planExpression: PlanExpression[PlanType@unchecked] => + val newPlan = planExpression.plan.transformUpWithSubqueriesAndPruning(cond, ruleId)(f) + planExpression.withNewPlan(newPlan) + } + f.applyOrElse[PlanType, PlanType](transformed, identity) + } + } + + transformUpWithPruning(cond, ruleId)(g) + } + /** * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. * Returns a copy of this node where the given partial function has been recursively applied diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index a386e9bf4efe6..5bb8cf1dd3b8a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, CommonExpressionDef, CommonExpressionRef, With} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -29,7 +29,9 @@ import org.apache.spark.sql.types.IntegerType class RewriteWithExpressionSuite extends PlanTest { object Optimizer extends RuleExecutor[LogicalPlan] { - val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil + val batches = Batch("Rewrite With expression", Once, + PullOutGroupingExpressions, + RewriteWithExpression) :: Nil } private val testRelation = LocalRelation($"a".int, $"b".int) @@ -37,17 +39,19 @@ class RewriteWithExpressionSuite extends PlanTest { test("simple common expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a) - val ref = new CommonExpressionRef(commonExprDef) - val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col")) + val expr = With.create((a, 0)) { case Seq(ref) => + ref + ref + } + val plan = testRelation.select(expr.as("col")) comparePlans(Optimizer.execute(plan), testRelation.select((a + a).as("col"))) } test("non-cheap common expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val plan = testRelation.select(With(ref * ref, Seq(commonExprDef)).as("col")) + val expr = With.create((a + a, 0)) { case Seq(ref) => + ref * ref + } + val plan = testRelation.select(expr.as("col")) val commonExprName = "_common_expr_0" comparePlans( Optimizer.execute(plan), @@ -60,16 +64,16 @@ class RewriteWithExpressionSuite extends PlanTest { test("nested WITH expression in the definition expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val innerExpr = With(ref + ref, Seq(commonExprDef)) + val innerExpr = With.create((a + a, 0)) { case Seq(ref) => + ref + ref + } val innerCommonExprName = "_common_expr_0" val b = testRelation.output.last - val outerCommonExprDef = CommonExpressionDef(innerExpr + b) - val outerRef = new CommonExpressionRef(outerCommonExprDef) - val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef)) - val outerCommonExprName = "_common_expr_0" + val outerExpr = With.create((innerExpr + b, 1)) { case Seq(ref) => + ref * ref + } + val outerCommonExprName = "_common_expr_1" val plan = testRelation.select(outerExpr.as("col")) val rewrittenOuterExpr = ($"$innerCommonExprName" + $"$innerCommonExprName" + b) @@ -88,16 +92,16 @@ class RewriteWithExpressionSuite extends PlanTest { test("nested WITH expression in the main expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val innerExpr = With(ref + ref, Seq(commonExprDef)) + val innerExpr = With.create((a + a, 0)) { case Seq(ref) => + ref + ref + } val innerCommonExprName = "_common_expr_0" val b = testRelation.output.last - val outerCommonExprDef = CommonExpressionDef(b + b) - val outerRef = new CommonExpressionRef(outerCommonExprDef) - val outerExpr = With(outerRef * outerRef + innerExpr, Seq(outerCommonExprDef)) - val outerCommonExprName = "_common_expr_0" + val outerExpr = With.create((b + b, 1)) { case Seq(ref) => + ref * ref + innerExpr + } + val outerCommonExprName = "_common_expr_1" val plan = testRelation.select(outerExpr.as("col")) val rewrittenInnerExpr = (a + a).as(innerCommonExprName) @@ -116,12 +120,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("correlated nested WITH expression is not supported") { val b = testRelation.output.last - val outerCommonExprDef = CommonExpressionDef(b + b) + val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0)) val outerRef = new CommonExpressionRef(outerCommonExprDef) val a = testRelation.output.head // The inner expression definition references the outer expression - val commonExprDef1 = CommonExpressionDef(a + a + outerRef) + val commonExprDef1 = CommonExpressionDef(a + a + outerRef, CommonExpressionId(1)) val ref1 = new CommonExpressionRef(commonExprDef1) val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1)) @@ -139,9 +143,10 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in filter") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val plan = testRelation.where(With(ref < 10 && ref > 0, Seq(commonExprDef))) + val condition = With.create((a + a, 0)) { case Seq(ref) => + ref < 10 && ref > 0 + } + val plan = testRelation.where(condition) val commonExprName = "_common_expr_0" comparePlans( Optimizer.execute(plan), @@ -155,9 +160,9 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: only reference left child") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val condition = With.create((a + a, 0)) { case Seq(ref) => + ref < 10 && ref > 0 + } val plan = testRelation.join(testRelation2, condition = Some(condition)) val commonExprName = "_common_expr_0" comparePlans( @@ -172,9 +177,9 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: only reference right child") { val x = testRelation2.output.head - val commonExprDef = CommonExpressionDef(x + x) - val ref = new CommonExpressionRef(commonExprDef) - val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val condition = With.create((x + x, 0)) { case Seq(ref) => + ref < 10 && ref > 0 + } val plan = testRelation.join(testRelation2, condition = Some(condition)) val commonExprName = "_common_expr_0" comparePlans( @@ -192,9 +197,9 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: reference both children") { val a = testRelation.output.head val x = testRelation2.output.head - val commonExprDef = CommonExpressionDef(a + x) - val ref = new CommonExpressionRef(commonExprDef) - val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val condition = With.create((a + x, 0)) { case Seq(ref) => + ref < 10 && ref > 0 + } val plan = testRelation.join(testRelation2, condition = Some(condition)) comparePlans( Optimizer.execute(plan), @@ -209,15 +214,17 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression inside conditional expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val expr = Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef)))) + val expr = Coalesce(Seq(a, With.create((a + a, 0)) { case Seq(ref) => + ref * ref + })) val inlinedExpr = Coalesce(Seq(a, (a + a) * (a + a))) val plan = testRelation.select(expr.as("col")) // With in the conditional branches is always inlined. comparePlans(Optimizer.execute(plan), testRelation.select(inlinedExpr.as("col"))) - val expr2 = Coalesce(Seq(With(ref * ref, Seq(commonExprDef)), a)) + val expr2 = Coalesce(Seq(With.create((a + a, 0)) { case Seq(ref) => + ref * ref + }, a)) val plan2 = testRelation.select(expr2.as("col")) val commonExprName = "_common_expr_0" // With in the always-evaluated branches can still be optimized. @@ -229,4 +236,85 @@ class RewriteWithExpressionSuite extends PlanTest { .analyze ) } + + test("WITH expression in grouping exprs") { + val a = testRelation.output.head + val expr1 = With.create((a + 1, 0)) { case Seq(ref) => + ref * ref + } + val expr2 = With.create((a + 1, 1)) { case Seq(ref) => + ref * ref + } + val expr3 = With.create((a + 1, 2)) { case Seq(ref) => + ref * ref + } + val plan = testRelation.groupBy(expr1)( + (expr2 + 2).as("col1"), + count(expr3 - 3).as("col2") + ) + val commonExpr1Name = "_common_expr_0" + // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions. + val commonExpr2Name = "_common_expr_2" + val groupingExprName = "_groupingexpression" + val countAlias = count(expr3 - 3).toString + comparePlans( + Optimizer.execute(plan), + testRelation + .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .select(testRelation.output :+ + ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*) + .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*) + .groupBy($"$groupingExprName")( + $"$groupingExprName", + count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias) + ) + .select(($"$groupingExprName" + 2).as("col1"), $"`$countAlias`".as("col2")) + .analyze + ) + // Running CollapseProject after the rule cleans up the unnecessary projections. + comparePlans( + CollapseProject(Optimizer.execute(plan)), + testRelation + .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .select(testRelation.output ++ Seq( + ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName), + (a + 1).as(commonExpr2Name)): _*) + .groupBy($"$groupingExprName")( + ($"$groupingExprName" + 2).as("col1"), + count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2") + ) + .analyze + ) + } + + test("WITH expression in aggregate exprs") { + val Seq(a, b) = testRelation.output + val expr1 = With.create((a + 1, 0)) { case Seq(ref) => + ref * ref + } + val expr2 = With.create((b + 2, 1)) { case Seq(ref) => + ref * ref + } + val plan = testRelation.groupBy(a)( + (a + 3).as("col1"), + expr1.as("col2"), + max(expr2).as("col3") + ) + val commonExpr1Name = "_common_expr_0" + val commonExpr2Name = "_common_expr_1" + val maxAlias = max(expr2).toString + comparePlans( + Optimizer.execute(plan), + testRelation + .select(testRelation.output :+ (b + 2).as(commonExpr2Name): _*) + .groupBy(a)(a, max($"$commonExpr2Name" * $"$commonExpr2Name").as(maxAlias)) + .select(a, $"`$maxAlias`", (a + 1).as(commonExpr1Name)) + .select( + (a + 3).as("col1"), + ($"$commonExpr1Name" * $"$commonExpr1Name").as("col2"), + $"`$maxAlias`".as("col3") + ) + .analyze + ) + } } From 33db2bdab7c8ecbc6e66a22302d7e7c966aedf02 Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Fri, 12 Apr 2024 18:42:48 -0700 Subject: [PATCH 2/8] fix compilation errors --- .../sql/catalyst/optimizer/RewriteWithExpression.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index 282b9254c8a98..e29ec9f9b840b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{ALIAS, COMMON_EXPR_REF, WITH_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION} /** * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or @@ -40,15 +40,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) { case p @ PhysicalAggregation( - groupingExpressions, aggregateExpressions, resultExpressions, child, limit) + groupingExpressions, aggregateExpressions, resultExpressions, child) if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => // For aggregates, separate computation of the aggregations themselves from the final // result by moving the final result computation into a projection above. This prevents // this rule from producing an invalid Aggregate operator. // TODO: the names of these aliases will become outdated after the rewrite - val aggExprs = aggregateExpressions.map(ae => Alias(ae, ae.toString)(ae.resultIds.head)) + val aggExprs = aggregateExpressions.map(ae => Alias(ae, ae.toString)(ae.resultId)) // Rewrite the projection and the aggregate separately and then piece them together. - val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child, limit) + val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child) val rewrittenAgg = applyInternal(agg) val proj = Project(resultExpressions, rewrittenAgg) applyInternal(proj) From 2aad8418f1a48f2a4e442e8684ebaef655376cd6 Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Mon, 15 Apr 2024 14:57:35 -0700 Subject: [PATCH 3/8] addressed comments --- .../optimizer/RewriteWithExpression.scala | 21 ++++--- .../sql/catalyst/planning/patterns.scala | 1 + .../RewriteWithExpressionSuite.scala | 57 ++++++++++++++++--- 3 files changed, 65 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index e29ec9f9b840b..a90472f2c8222 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -39,18 +39,25 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EX object RewriteWithExpression extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) { + // For aggregates, separate the computation of the aggregations themselves from the final + // result by moving the final result computation into a projection above it. This prevents + // this rule from producing an invalid Aggregate operator. case p @ PhysicalAggregation( groupingExpressions, aggregateExpressions, resultExpressions, child) if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => - // For aggregates, separate computation of the aggregations themselves from the final - // result by moving the final result computation into a projection above. This prevents - // this rule from producing an invalid Aggregate operator. - // TODO: the names of these aliases will become outdated after the rewrite - val aggExprs = aggregateExpressions.map(ae => Alias(ae, ae.toString)(ae.resultId)) + // PhysicalAggregation returns aggregateExpressions as attribute references, which we change + // to aliases so that they can be referred to by resultExpressions. + val aggExprs = aggregateExpressions.map( + ae => Alias(ae, "_aggregateexpression")(ae.resultId)) + val aggExprIds = aggExprs.map(_.exprId).toSet + val resExprs = resultExpressions.map(_.transform { + case a: AttributeReference if aggExprIds.contains(a.exprId) => + a.withName("_aggregateexpression") + }.asInstanceOf[NamedExpression]) // Rewrite the projection and the aggregate separately and then piece them together. val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child) val rewrittenAgg = applyInternal(agg) - val proj = Project(resultExpressions, rewrittenAgg) + val proj = Project(resExprs, rewrittenAgg) applyInternal(proj) case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => applyInternal(p) @@ -62,7 +69,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] { var newPlan: LogicalPlan = p.mapExpressions { expr => rewriteWithExprAndInputPlans(expr, inputPlans) } - newPlan = newPlan.withNewChildren(inputPlans) + newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq) // Since we add extra Projects with extra columns to pre-evaluate the common expressions, // the current operator may have extra columns if it inherits the output columns from its // child, and we need to project away the extra columns to keep the plan schema unchanged. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index e48b44a603ad7..62ac55b34e026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE_EXPRESSION import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE import org.apache.spark.sql.errors.QueryCompilationErrors diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index 5bb8cf1dd3b8a..5c40df76bfa4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -256,7 +256,7 @@ class RewriteWithExpressionSuite extends PlanTest { // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions. val commonExpr2Name = "_common_expr_2" val groupingExprName = "_groupingexpression" - val countAlias = count(expr3 - 3).toString + val aggExprName = "_aggregateexpression" comparePlans( Optimizer.execute(plan), testRelation @@ -266,9 +266,9 @@ class RewriteWithExpressionSuite extends PlanTest { .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*) .groupBy($"$groupingExprName")( $"$groupingExprName", - count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(countAlias) + count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(aggExprName) ) - .select(($"$groupingExprName" + 2).as("col1"), $"`$countAlias`".as("col2")) + .select(($"$groupingExprName" + 2).as("col1"), $"`$aggExprName`".as("col2")) .analyze ) // Running CollapseProject after the rule cleans up the unnecessary projections. @@ -302,19 +302,62 @@ class RewriteWithExpressionSuite extends PlanTest { ) val commonExpr1Name = "_common_expr_0" val commonExpr2Name = "_common_expr_1" - val maxAlias = max(expr2).toString + val aggExprName = "_aggregateexpression" comparePlans( Optimizer.execute(plan), testRelation .select(testRelation.output :+ (b + 2).as(commonExpr2Name): _*) - .groupBy(a)(a, max($"$commonExpr2Name" * $"$commonExpr2Name").as(maxAlias)) - .select(a, $"`$maxAlias`", (a + 1).as(commonExpr1Name)) + .groupBy(a)(a, max($"$commonExpr2Name" * $"$commonExpr2Name").as(aggExprName)) + .select(a, $"`$aggExprName`", (a + 1).as(commonExpr1Name)) .select( (a + 3).as("col1"), ($"$commonExpr1Name" * $"$commonExpr1Name").as("col2"), - $"`$maxAlias`".as("col3") + $"`$aggExprName`".as("col3") ) .analyze ) } + + test("WITH common expression is aggregate function") { + val a = testRelation.output.head + val expr = With.create((count(a - 1), 0)) { case Seq(ref) => + ref * ref + } + val plan = testRelation.groupBy(a)( + (a - 1).as("col1"), + expr.as("col2"), + ) + val aggExprName = "_aggregateexpression" + comparePlans( + Optimizer.execute(plan), + testRelation + .groupBy(a)(a, count(a - 1).as(aggExprName)) + .select( + (a - 1).as("col1"), + ($"$aggExprName" * $"$aggExprName").as("col2") + ) + .analyze + ) + } + + test("WITH expression is aggregate function") { + val a = testRelation.output.head + val expr = With.create((a - 1, 0)) { case Seq(ref) => + sum(ref * ref) + } + val plan = testRelation.groupBy(a)( + (a - 1).as("col1"), + expr.as("col2") + ) + val commonExpr1Name = "_common_expr_0" + val aggExprName = "_aggregateexpression" + comparePlans( + Optimizer.execute(plan), + testRelation + .select(testRelation.output :+ (a - 1).as(commonExpr1Name): _*) + .groupBy(a)(a, sum($"$commonExpr1Name" * $"$commonExpr1Name").as(aggExprName)) + .select((a - 1).as("col1"), $"$aggExprName".as("col2")) + .analyze + ) + } } From 7815ab25eddfc2fedee133525bfffce79b240cbd Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Mon, 15 Apr 2024 15:21:24 -0700 Subject: [PATCH 4/8] remove unused import --- .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 62ac55b34e026..e48b44a603ad7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE_EXPRESSION import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE import org.apache.spark.sql.errors.QueryCompilationErrors From f503b720d99cff4025ed64ff0f4d13c474ba8d4d Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Tue, 16 Apr 2024 10:59:01 -0700 Subject: [PATCH 5/8] addressed comments --- .../spark/sql/catalyst/expressions/With.scala | 11 -- .../optimizer/RewriteWithExpression.scala | 4 + .../RewriteWithExpressionSuite.scala | 100 +++++++++--------- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala index fa72a6fc121d2..2745b663639f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -88,17 +88,6 @@ object With { val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_)) With(replaced(commonExprRefs), commonExprDefs) } - - /** - * Used for testing to specify the exact common expression ID for each common expression. - */ - def create(commonExprs: (Expression, Long)*)(replaced: Seq[Expression] => Expression): With = { - val commonExprDefs = commonExprs.map { case (expr, exprId) => - CommonExpressionDef(expr, CommonExpressionId(exprId)) - } - val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_)) - With(replaced(commonExprRefs), commonExprDefs) - } } case class CommonExpressionId(id: Long = CommonExpressionId.newId, canonicalized: Boolean = false) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index a90472f2c8222..7f947b379f967 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -45,6 +45,10 @@ object RewriteWithExpression extends Rule[LogicalPlan] { case p @ PhysicalAggregation( groupingExpressions, aggregateExpressions, resultExpressions, child) if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => + // There should not be dangling common expression references in the aggregate expressions. + // This can happen if a With is created with an aggregate function in its child. + assert(!aggregateExpressions.exists(ae => + !ae.containsPattern(WITH_EXPRESSION) && ae.containsPattern(COMMON_EXPR_REF))) // PhysicalAggregation returns aggregateExpressions as attribute references, which we change // to aliases so that they can be referred to by resultExpressions. val aggExprs = aggregateExpressions.map( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index 5c40df76bfa4f..82771ca726dfc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -39,7 +39,7 @@ class RewriteWithExpressionSuite extends PlanTest { test("simple common expression") { val a = testRelation.output.head - val expr = With.create((a, 0)) { case Seq(ref) => + val expr = With(a) { case Seq(ref) => ref + ref } val plan = testRelation.select(expr.as("col")) @@ -48,11 +48,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("non-cheap common expression") { val a = testRelation.output.head - val expr = With.create((a + a, 0)) { case Seq(ref) => + val expr = With(a + a) { case Seq(ref) => ref * ref } val plan = testRelation.select(expr.as("col")) - val commonExprName = "_common_expr_0" + val commonExprId = expr.defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation @@ -64,16 +65,18 @@ class RewriteWithExpressionSuite extends PlanTest { test("nested WITH expression in the definition expression") { val a = testRelation.output.head - val innerExpr = With.create((a + a, 0)) { case Seq(ref) => + val innerExpr = With(a + a) { case Seq(ref) => ref + ref } - val innerCommonExprName = "_common_expr_0" + val innerCommonExprId = innerExpr.defs.head.id.id + val innerCommonExprName = s"_common_expr_$innerCommonExprId" val b = testRelation.output.last - val outerExpr = With.create((innerExpr + b, 1)) { case Seq(ref) => + val outerExpr = With(innerExpr + b) { case Seq(ref) => ref * ref } - val outerCommonExprName = "_common_expr_1" + val outerCommonExprId = outerExpr.defs.head.id.id + val outerCommonExprName = s"_common_expr_$outerCommonExprId" val plan = testRelation.select(outerExpr.as("col")) val rewrittenOuterExpr = ($"$innerCommonExprName" + $"$innerCommonExprName" + b) @@ -92,16 +95,18 @@ class RewriteWithExpressionSuite extends PlanTest { test("nested WITH expression in the main expression") { val a = testRelation.output.head - val innerExpr = With.create((a + a, 0)) { case Seq(ref) => + val innerExpr = With(a + a) { case Seq(ref) => ref + ref } - val innerCommonExprName = "_common_expr_0" + val innerCommonExprId = innerExpr.defs.head.id.id + val innerCommonExprName = s"_common_expr_$innerCommonExprId" val b = testRelation.output.last - val outerExpr = With.create((b + b, 1)) { case Seq(ref) => + val outerExpr = With(b + b) { case Seq(ref) => ref * ref + innerExpr } - val outerCommonExprName = "_common_expr_1" + val outerCommonExprId = outerExpr.defs.head.id.id + val outerCommonExprName = s"_common_expr_$outerCommonExprId" val plan = testRelation.select(outerExpr.as("col")) val rewrittenInnerExpr = (a + a).as(innerCommonExprName) @@ -143,11 +148,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in filter") { val a = testRelation.output.head - val condition = With.create((a + a, 0)) { case Seq(ref) => + val condition = With(a + a) { case Seq(ref) => ref < 10 && ref > 0 } val plan = testRelation.where(condition) - val commonExprName = "_common_expr_0" + val commonExprId = condition.defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation @@ -160,11 +166,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: only reference left child") { val a = testRelation.output.head - val condition = With.create((a + a, 0)) { case Seq(ref) => + val condition = With(a + a) { case Seq(ref) => ref < 10 && ref > 0 } val plan = testRelation.join(testRelation2, condition = Some(condition)) - val commonExprName = "_common_expr_0" + val commonExprId = condition.defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation @@ -177,11 +184,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: only reference right child") { val x = testRelation2.output.head - val condition = With.create((x + x, 0)) { case Seq(ref) => + val condition = With(x + x) { case Seq(ref) => ref < 10 && ref > 0 } val plan = testRelation.join(testRelation2, condition = Some(condition)) - val commonExprName = "_common_expr_0" + val commonExprId = condition.defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation @@ -197,7 +205,7 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: reference both children") { val a = testRelation.output.head val x = testRelation2.output.head - val condition = With.create((a + x, 0)) { case Seq(ref) => + val condition = With(a + x) { case Seq(ref) => ref < 10 && ref > 0 } val plan = testRelation.join(testRelation2, condition = Some(condition)) @@ -214,7 +222,7 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression inside conditional expression") { val a = testRelation.output.head - val expr = Coalesce(Seq(a, With.create((a + a, 0)) { case Seq(ref) => + val expr = Coalesce(Seq(a, With(a + a) { case Seq(ref) => ref * ref })) val inlinedExpr = Coalesce(Seq(a, (a + a) * (a + a))) @@ -222,11 +230,12 @@ class RewriteWithExpressionSuite extends PlanTest { // With in the conditional branches is always inlined. comparePlans(Optimizer.execute(plan), testRelation.select(inlinedExpr.as("col"))) - val expr2 = Coalesce(Seq(With.create((a + a, 0)) { case Seq(ref) => + val expr2 = Coalesce(Seq(With(a + a) { case Seq(ref) => ref * ref }, a)) val plan2 = testRelation.select(expr2.as("col")) - val commonExprName = "_common_expr_0" + val commonExprId = expr2.children.head.asInstanceOf[With].defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" // With in the always-evaluated branches can still be optimized. comparePlans( Optimizer.execute(plan2), @@ -239,22 +248,24 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in grouping exprs") { val a = testRelation.output.head - val expr1 = With.create((a + 1, 0)) { case Seq(ref) => + val expr1 = With(a + 1) { case Seq(ref) => ref * ref } - val expr2 = With.create((a + 1, 1)) { case Seq(ref) => + val expr2 = With(a + 1) { case Seq(ref) => ref * ref } - val expr3 = With.create((a + 1, 2)) { case Seq(ref) => + val expr3 = With(a + 1) { case Seq(ref) => ref * ref } val plan = testRelation.groupBy(expr1)( (expr2 + 2).as("col1"), count(expr3 - 3).as("col2") ) - val commonExpr1Name = "_common_expr_0" - // Note that _common_expr_1 gets deduplicated by PullOutGroupingExpressions. - val commonExpr2Name = "_common_expr_2" + val commonExpr1Id = expr1.defs.head.id.id + val commonExpr1Name = s"_common_expr_$commonExpr1Id" + // Note that the common expression in expr2 gets de-duplicated by PullOutGroupingExpressions. + val commonExpr3Id = expr3.defs.head.id.id + val commonExpr3Name = s"_common_expr_$commonExpr3Id" val groupingExprName = "_groupingexpression" val aggExprName = "_aggregateexpression" comparePlans( @@ -263,10 +274,10 @@ class RewriteWithExpressionSuite extends PlanTest { .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) .select(testRelation.output :+ ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*) - .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr2Name)): _*) + .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr3Name)): _*) .groupBy($"$groupingExprName")( $"$groupingExprName", - count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as(aggExprName) + count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as(aggExprName) ) .select(($"$groupingExprName" + 2).as("col1"), $"`$aggExprName`".as("col2")) .analyze @@ -278,10 +289,10 @@ class RewriteWithExpressionSuite extends PlanTest { .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) .select(testRelation.output ++ Seq( ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName), - (a + 1).as(commonExpr2Name)): _*) + (a + 1).as(commonExpr3Name)): _*) .groupBy($"$groupingExprName")( ($"$groupingExprName" + 2).as("col1"), - count($"$commonExpr2Name" * $"$commonExpr2Name" - 3).as("col2") + count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as("col2") ) .analyze ) @@ -289,10 +300,10 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in aggregate exprs") { val Seq(a, b) = testRelation.output - val expr1 = With.create((a + 1, 0)) { case Seq(ref) => + val expr1 = With(a + 1) { case Seq(ref) => ref * ref } - val expr2 = With.create((b + 2, 1)) { case Seq(ref) => + val expr2 = With(b + 2) { case Seq(ref) => ref * ref } val plan = testRelation.groupBy(a)( @@ -300,8 +311,10 @@ class RewriteWithExpressionSuite extends PlanTest { expr1.as("col2"), max(expr2).as("col3") ) - val commonExpr1Name = "_common_expr_0" - val commonExpr2Name = "_common_expr_1" + val commonExpr1Id = expr1.defs.head.id.id + val commonExpr1Name = s"_common_expr_$commonExpr1Id" + val commonExpr2Id = expr2.defs.head.id.id + val commonExpr2Name = s"_common_expr_$commonExpr2Id" val aggExprName = "_aggregateexpression" comparePlans( Optimizer.execute(plan), @@ -320,7 +333,7 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH common expression is aggregate function") { val a = testRelation.output.head - val expr = With.create((count(a - 1), 0)) { case Seq(ref) => + val expr = With(count(a - 1)) { case Seq(ref) => ref * ref } val plan = testRelation.groupBy(a)( @@ -340,24 +353,15 @@ class RewriteWithExpressionSuite extends PlanTest { ) } - test("WITH expression is aggregate function") { + test("aggregate functions in child of WITH expression is not supported") { val a = testRelation.output.head - val expr = With.create((a - 1, 0)) { case Seq(ref) => + val expr = With(a - 1) { case Seq(ref) => sum(ref * ref) } val plan = testRelation.groupBy(a)( (a - 1).as("col1"), expr.as("col2") ) - val commonExpr1Name = "_common_expr_0" - val aggExprName = "_aggregateexpression" - comparePlans( - Optimizer.execute(plan), - testRelation - .select(testRelation.output :+ (a - 1).as(commonExpr1Name): _*) - .groupBy(a)(a, sum($"$commonExpr1Name" * $"$commonExpr1Name").as(aggExprName)) - .select((a - 1).as("col1"), $"$aggExprName".as("col2")) - .analyze - ) + intercept[java.lang.AssertionError](Optimizer.execute(plan)) } } From 287b0d8a80d3df10ef5f79d45f4c65e4a3a0900e Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Tue, 16 Apr 2024 14:00:29 -0700 Subject: [PATCH 6/8] scalastyle --- .../sql/catalyst/optimizer/RewriteWithExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index 82771ca726dfc..b7057bf95dce0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -338,7 +338,7 @@ class RewriteWithExpressionSuite extends PlanTest { } val plan = testRelation.groupBy(a)( (a - 1).as("col1"), - expr.as("col2"), + expr.as("col2") ) val aggExprName = "_aggregateexpression" comparePlans( From cce8e86a5fb23d2f7a4624ff7504b0927d926da7 Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Wed, 17 Apr 2024 13:27:55 -0700 Subject: [PATCH 7/8] make alias names consistent for tests --- .../sql/connect/ProtoToParsedPlanTestSuite.scala | 1 + .../catalyst/optimizer/RewriteWithExpression.scala | 10 ++++++++-- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 11 +++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala index cc9decb4c98bc..d404779d7a92f 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala @@ -126,6 +126,7 @@ class ProtoToParsedPlanTestSuite Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES.key, "org.apache.spark.sql.connect.plugin.ExampleExpressionPlugin") .set(org.apache.spark.sql.internal.SQLConf.ANSI_ENABLED.key, false.toString) + .set(org.apache.spark.sql.internal.SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key, false.toString) } protected val suiteBaseResourcePath = commonResourcePath.resolve("query-tests") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index 7f947b379f967..71e0ac9db787f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION} +import org.apache.spark.sql.internal.SQLConf /** * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or @@ -98,7 +99,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] { val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression] val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias]) - defs.foreach { case CommonExpressionDef(child, id) => + defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) => if (child.containsPattern(COMMON_EXPR_REF)) { throw SparkException.internalError( "Common expression definition cannot reference other Common expression definitions") @@ -125,7 +126,12 @@ object RewriteWithExpression extends Rule[LogicalPlan] { // if it's ref count is 1. refToExpr(id) = child } else { - val alias = Alias(child, s"_common_expr_${id.id}")() + val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) { + s"_common_expr_${id.id}" + } else { + s"_common_expr_$index" + } + val alias = Alias(child, aliasName)() val fakeProj = Project(Seq(alias), inputPlans(childProjectionIndex)) if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) { // We have to inline the common expression if it cannot be put in a Project. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c8a5d997da7d2..5bc80f9b28b11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3429,6 +3429,17 @@ object SQLConf { .booleanConf .createWithDefault(false) + val USE_COMMON_EXPR_ID_FOR_ALIAS = + buildConf("spark.sql.useCommonExprIdForAlias") + .internal() + .doc("When true, use the common expression ID for the alias when rewriting With " + + "expressions. Otherwise, use the index of the common expression definition. When true " + + "this avoids duplicate alias names, but is helpful to set to false for testing to ensure" + + "that alias names are consistent.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES = buildConf("spark.sql.defaultColumn.useNullsForMissingDefaultValues") .internal() From c68ab9cc005a1dd3bc51e5ed23bf0d06ac8edbdb Mon Sep 17 00:00:00 2001 From: Kelvin Jiang Date: Wed, 17 Apr 2024 13:48:36 -0700 Subject: [PATCH 8/8] move assertion --- .../explain-results/function_count_if.explain | 7 ++++--- .../spark/sql/catalyst/expressions/With.scala | 6 +++++- .../optimizer/RewriteWithExpression.scala | 4 ---- .../optimizer/RewriteWithExpressionSuite.scala | 16 +++++++++------- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain index f2ada15eccb7d..a9fd2eeb669aa 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain @@ -1,3 +1,4 @@ -Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L] -+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0] - +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] +Project [_aggregateexpression#0L AS count_if((a > 0))#0L] ++- Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS _aggregateexpression#0L] + +- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0] + +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala index 2745b663639f8..14deedd9c70fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} import org.apache.spark.sql.types.DataType /** @@ -27,6 +27,10 @@ import org.apache.spark.sql.types.DataType */ case class With(child: Expression, defs: Seq[CommonExpressionDef]) extends Expression with Unevaluable { + // We do not allow With to be created with an AggregateExpression in the child, as this would + // create a dangling CommonExpressionRef after rewriting it in RewriteWithExpression. + assert(!child.containsPattern(AGGREGATE_EXPRESSION)) + override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION) override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index 71e0ac9db787f..393a66f7c1e4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -46,10 +46,6 @@ object RewriteWithExpression extends Rule[LogicalPlan] { case p @ PhysicalAggregation( groupingExpressions, aggregateExpressions, resultExpressions, child) if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => - // There should not be dangling common expression references in the aggregate expressions. - // This can happen if a With is created with an aggregate function in its child. - assert(!aggregateExpressions.exists(ae => - !ae.containsPattern(WITH_EXPRESSION) && ae.containsPattern(COMMON_EXPR_REF))) // PhysicalAggregation returns aggregateExpressions as attribute references, which we change // to aliases so that they can be referred to by resultExpressions. val aggExprs = aggregateExpressions.map( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index b7057bf95dce0..d482b18d93316 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -355,13 +355,15 @@ class RewriteWithExpressionSuite extends PlanTest { test("aggregate functions in child of WITH expression is not supported") { val a = testRelation.output.head - val expr = With(a - 1) { case Seq(ref) => - sum(ref * ref) + intercept[java.lang.AssertionError] { + val expr = With(a - 1) { case Seq(ref) => + sum(ref * ref) + } + val plan = testRelation.groupBy(a)( + (a - 1).as("col1"), + expr.as("col2") + ) + Optimizer.execute(plan) } - val plan = testRelation.groupBy(a)( - (a - 1).as("col1"), - expr.as("col2") - ) - intercept[java.lang.AssertionError](Optimizer.execute(plan)) } }