diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
index f506bdddc16b5..f1e053f7fb2a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
 import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}
 
 /**
@@ -40,4 +40,28 @@ trait BaseAggregateExec extends UnaryExecNode {
        |${ExplainUtils.generateFieldString("Results", resultExpressions)}
        |""".stripMargin
   }
+
+  protected def inputAttributes: Seq[Attribute] = {
+    val modes = aggregateExpressions.map(_.mode).distinct
+    if (modes.contains(Final) || modes.contains(PartialMerge)) {
+      // SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's
+      // `inputAggBufferAttributes` as its output. And Final and PartialMerge aggregates rely on the
+      // output to bind references for `DeclarativeAggregate.mergeExpressions`. But if we copy the
+      // aggregate function somehow after aggregate planning, like `PlanSubqueries`, the
+      // `DeclarativeAggregate` will be replaced by a new instance with new
+      // `inputAggBufferAttributes` and `mergeExpressions`. Then Final and PartialMerge aggregates
+      // can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
+      // the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
+      // we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
+      val aggAttrs = aggregateExpressions
+        // there are exactly four cases that need `inputAggBufferAttributes` from child according
+        // to the agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
+        // Partial -> PartialMerge, PartialMerge -> PartialMerge.
+        .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
+        .flatMap(_.inputAggBufferAttributes)
+      child.output.dropRight(aggAttrs.length) ++ aggAttrs
+    } else {
+      child.output
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 8af17ed0e1639..9c07ea10a87e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -129,7 +129,7 @@ case class HashAggregateExec(
             resultExpressions,
             (expressions, inputSchema) =>
               MutableProjection.create(expressions, inputSchema),
-            child.output,
+            inputAttributes,
             iter,
             testFallbackStartsAt,
             numOutputRows,
@@ -334,7 +334,7 @@ case class HashAggregateExec(
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
     // To individually generate code for each aggregate function, an element in `updateExprs` holds
     // all the expressions for the buffer of an aggregation function.
     val updateExprs = aggregateExpressions.map { e =>
@@ -931,7 +931,7 @@ case class HashAggregateExec(
       }
     }
 
-    val inputAttr = aggregateBufferAttributes ++ child.output
+    val inputAttr = aggregateBufferAttributes ++ inputAttributes
     // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
     // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
     // generating input columns, we use `currentVars`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 3fb58eb2cc8ba..f1c0719ff8948 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -123,7 +123,7 @@ case class ObjectHashAggregateExec(
             resultExpressions,
             (expressions, inputSchema) =>
               MutableProjection.create(expressions, inputSchema),
-            child.output,
+            inputAttributes,
             iter,
             fallbackCountThreshold,
             numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 9610eab82c7cb..ba0c3517a1a14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -88,7 +88,7 @@ case class SortAggregateExec(
         val outputIter = new SortBasedAggregationIterator(
           partIndex,
           groupingExpressions,
-          child.output,
+          inputAttributes,
           iter,
           aggregateExpressions,
           aggregateAttributes,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 4edf3a5d39fdd..2293d4ae61aff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -973,4 +973,43 @@ class DataFrameAggregateSuite extends QueryTest
       assert(error.message.contains("function count_if requires boolean type"))
     }
   }
+
+  Seq(true, false).foreach { value =>
+    test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
+        withTempView("t1", "t2") {
+          sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
+          sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
+
+          // test without grouping keys
+          checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
+            Row(4) :: Nil)
+
+          // test with grouping keys
+          checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " +
+            "t2 group by c"), Row(3, 4) :: Nil)
+
+          // test with distinct
+          checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil)
+
+          // test subquery with agg
+          checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
+
+          // test SortAggregateExec
+          var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find(_.isInstanceOf[SortAggregateExec]).isDefined)
+          checkAnswer(df, Row("str1") :: Nil)
+
+          // test ObjectHashAggregateExec
+          df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find(_.isInstanceOf[ObjectHashAggregateExec]).isDefined)
+          checkAnswer(df, Row(Array(4), 4) :: Nil)
+        }
+      }
+    }
+  }
 }
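
Reviewer note: a minimal end-to-end reproduction sketch of the failure this patch fixes, runnable in spark-shell (assuming the usual `spark` session; the view names mirror the new test above, nothing here is part of the patch itself). The scalar subquery causes `PlanSubqueries` to copy the aggregate function after aggregate planning, so before this change the Final aggregate could not bind its `mergeExpressions` against the partial aggregate's output:

// Paste into spark-shell. Prior to this patch the last statement failed while
// binding references in the Final aggregate; with `inputAttributes` it returns
// a single row with csum = 4.
spark.sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
spark.sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
spark.sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2").show()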