From ba75e36013984c7021ebc0e2e1291f2be5d99528 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Mon, 11 May 2020 14:58:52 +0800
Subject: [PATCH 01/11] fix

---
 .../expressions/aggregate/interfaces.scala    | 20 +++++++++++++++++--
 .../spark/sql/DataFrameAggregateSuite.scala   |  9 +++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 222ad6fab19e..f8ebb8cc0b02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.types._
 
 /** The mode of an [[AggregateFunction]]. */
@@ -398,8 +399,20 @@ abstract class DeclarativeAggregate
   /** An expression-based aggregate's bufferSchema is derived from bufferAttributes. */
   final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
 
-  final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
+  final lazy val inputAggBufferAttributes: Seq[AttributeReference] = {
+    // SPARK-31620: inputAggBufferAttributes from a partial agg can be referenced by a final agg
+    // in order to merge agg values. However, in case of an aggregate function contains a subquery,
+    // the aggregate function will be transformed to new copy during `PlanSubqueries` and lost
+    // original attributes because `TreeNode` does not preserve "lazy val" during `makeCopy`. As a
+    // result, the final agg could fail to resolve references through partial agg. So we use the
+    // tag to save the original attributes to let the new copy node share the same attributes with
+    // old node.
+    getTagValue(inputAggBufferAttributeTag).getOrElse {
+      val attrs = aggBufferAttributes.map(_.newInstance())
+      setTagValue(inputAggBufferAttributeTag, attrs)
+      attrs
+    }
+  }
 
   /**
    * A helper class for representing an attribute used in merging two
@@ -415,6 +428,9 @@ abstract class DeclarativeAggregate
     /** Represents this attribute at the input buffer side (the data value is read-only). */
     def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
   }
+
+  private val inputAggBufferAttributeTag =
+    TreeNodeTag[Seq[AttributeReference]]("inputAggBufferAttributes")
 }
 
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 4edf3a5d39fd..67abca49c56f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -973,4 +973,13 @@ class DataFrameAggregateSuite extends QueryTest
       assert(error.message.contains("function count_if requires boolean type"))
     }
   }
+
+  test("SPARK-31620: agg with subquery") {
+    withTempView("t1", "t2") {
+      sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
+      sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
+      checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
+        Row(4) :: Nil)
+    }
+  }
 }
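
Patch 01 relies on two `TreeNode` behaviors: `makeCopy` rebuilds a node through its
constructor, so a `lazy val` is recomputed on the copy (yielding fresh expression
IDs), while tag values set with `setTagValue` are carried over to the copy. A
minimal, self-contained sketch of the failure mode in plain Scala; the names here
are illustrative stand-ins, not Spark's actual classes:

    import java.util.UUID

    // Stand-in for AttributeReference: identity is the generated id, not the name.
    case class Attr(name: String, id: UUID = UUID.randomUUID())

    class Agg(val name: String) {
      // Like the original inputAggBufferAttributes: computed per instance, so a
      // constructor-based copy ends up with brand-new ids.
      lazy val inputAggBufferAttributes: Seq[Attr] = Seq(Attr(name + "_buf"))

      // Mimics TreeNode.makeCopy: rebuild via constructor; lazy vals don't carry over.
      def makeCopy: Agg = new Agg(name)
    }

    val original = new Agg("sum")
    val copied = original.makeCopy
    // The partial aggregate exposed original's attributes; a final aggregate rebuilt
    // around the copy references different ids and can no longer resolve them.
    assert(original.inputAggBufferAttributes.map(_.id) !=
      copied.inputAggBufferAttributes.map(_.id))
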
From 4dcc1b1db3a0a49bd3889aa2a3b3e33a46f7a535 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Mon, 11 May 2020 15:16:41 +0800
Subject: [PATCH 02/11] update

---
 .../catalyst/expressions/aggregate/interfaces.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index f8ebb8cc0b02..6e1a4d71357d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -402,11 +402,11 @@ abstract class DeclarativeAggregate
   final lazy val inputAggBufferAttributes: Seq[AttributeReference] = {
     // SPARK-31620: inputAggBufferAttributes from a partial agg can be referenced by a final agg
     // in order to merge agg values. However, in case of an aggregate function contains a subquery,
-    // the aggregate function will be transformed to new copy during `PlanSubqueries` and lost
-    // original attributes because `TreeNode` does not preserve "lazy val" during `makeCopy`. As a
-    // result, the final agg could fail to resolve references through partial agg. So we use the
-    // tag to save the original attributes to let the new copy node share the same attributes with
-    // old node.
+    // the aggregate function will be transformed to a new copied node during `PlanSubqueries` and
+    // lost original attributes because `TreeNode` does not preserve "lazy val" during `makeCopy`.
+    // As a result, the final agg could fail to resolve references through partial agg. So we use
+    // the tag to save the original attributes to let the new copy node share the same attributes
+    // with old node.
     getTagValue(inputAggBufferAttributeTag).getOrElse {
       val attrs = aggBufferAttributes.map(_.newInstance())
       setTagValue(inputAggBufferAttributeTag, attrs)

From bc1b7e5c6ee2a34d4e83ac333719a5192db87c9b Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Mon, 11 May 2020 15:19:01 +0800
Subject: [PATCH 03/11] update2

---
 .../spark/sql/catalyst/expressions/aggregate/interfaces.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 6e1a4d71357d..383167ad12db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -405,7 +405,7 @@ abstract class DeclarativeAggregate
     // the aggregate function will be transformed to a new copied node during `PlanSubqueries` and
     // lost original attributes because `TreeNode` does not preserve "lazy val" during `makeCopy`.
     // As a result, the final agg could fail to resolve references through partial agg. So we use
-    // the tag to save the original attributes to let the new copy node share the same attributes
+    // the tag to save the original attributes to let the new copied node share the same attributes
     // with old node.
     getTagValue(inputAggBufferAttributeTag).getOrElse {
       val attrs = aggBufferAttributes.map(_.newInstance())

From cbd891d11c24b6afe0f8e886c9cd09bb450e2405 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Mon, 11 May 2020 21:36:16 +0800
Subject: [PATCH 04/11] update3

---
 .../expressions/aggregate/interfaces.scala    | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 383167ad12db..d90282dfd33f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -400,13 +400,15 @@ abstract class DeclarativeAggregate
   final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
 
   final lazy val inputAggBufferAttributes: Seq[AttributeReference] = {
-    // SPARK-31620: inputAggBufferAttributes from a partial agg can be referenced by a final agg
-    // in order to merge agg values. However, in case of an aggregate function contains a subquery,
-    // the aggregate function will be transformed to a new copied node during `PlanSubqueries` and
-    // lost original attributes because `TreeNode` does not preserve "lazy val" during `makeCopy`.
-    // As a result, the final agg could fail to resolve references through partial agg. So we use
-    // the tag to save the original attributes to let the new copied node share the same attributes
-    // with old node.
+    // SPARK-31620: inputAggBufferAttributes from a partial aggregate can be referenced by a final
+    // aggregate in order to merge aggregate values. However, in case of an aggregate function
+    // contains a subquery, the aggregate function will be transformed to a new copied node during
+    // `PlanSubqueries` and lost original attributes because `TreeNode` does not preserve them
+    // during `makeCopy`. As a result, the final aggregate could fail to resolve references through
+    // partial aggregate's output. So, we use the tag to save the original attributes to let the
+    // new copied node share the same attributes with old node. Note, we don't save other attributes
+    // within an aggregate function and ImperativeAggregate's inputAggBufferAttributes because they
+    // will not be referenced out of the aggregate function itself.
     getTagValue(inputAggBufferAttributeTag).getOrElse {
       val attrs = aggBufferAttributes.map(_.newInstance())
       setTagValue(inputAggBufferAttributeTag, attrs)

From 882467b769c243c6697e77240f2f312862f9c531 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Wed, 13 May 2020 17:17:14 +0800
Subject: [PATCH 05/11] update

---
 .../expressions/aggregate/interfaces.scala    | 21 ++----------
 .../aggregate/HashAggregateExec.scala         | 32 ++++++++++++++++---
 .../spark/sql/DataFrameAggregateSuite.scala   | 20 ++++++++----
 3 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index d90282dfd33f..aa29017aafe5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -399,22 +399,8 @@ abstract class DeclarativeAggregate
   /** An expression-based aggregate's bufferSchema is derived from bufferAttributes. */
   final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
 
-  final lazy val inputAggBufferAttributes: Seq[AttributeReference] = {
-    // SPARK-31620: inputAggBufferAttributes from a partial aggregate can be referenced by a final
-    // aggregate in order to merge aggregate values. However, in case of an aggregate function
-    // contains a subquery, the aggregate function will be transformed to a new copied node during
-    // `PlanSubqueries` and lost original attributes because `TreeNode` does not preserve them
-    // during `makeCopy`. As a result, the final aggregate could fail to resolve references through
-    // partial aggregate's output. So, we use the tag to save the original attributes to let the
-    // new copied node share the same attributes with old node. Note, we don't save other attributes
-    // within an aggregate function and ImperativeAggregate's inputAggBufferAttributes because they
-    // will not be referenced out of the aggregate function itself.
-    getTagValue(inputAggBufferAttributeTag).getOrElse {
-      val attrs = aggBufferAttributes.map(_.newInstance())
-      setTagValue(inputAggBufferAttributeTag, attrs)
-      attrs
-    }
-  }
+  final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
 
   /**
    * A helper class for representing an attribute used in merging two
@@ -430,9 +416,6 @@ abstract class DeclarativeAggregate
     /** Represents this attribute at the input buffer side (the data value is read-only). */
     def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
   }
-
-  private val inputAggBufferAttributeTag =
-    TreeNodeTag[Seq[AttributeReference]]("inputAggBufferAttributes")
 }
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 8af17ed0e163..91e1f27d9999 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -129,7 +129,7 @@ case class HashAggregateExec(
         resultExpressions,
         (expressions, inputSchema) =>
           MutableProjection.create(expressions, inputSchema),
-        child.output,
+        inputAttributes,
         iter,
         testFallbackStartsAt,
         numOutputRows,
@@ -331,10 +331,32 @@ case class HashAggregateExec(
     }
   }
 
+  private def inputAttributes: Seq[Attribute] = {
+    if (modes.contains(Final) || modes.contains(PartialMerge)) {
+      // SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's
+      // `inputAggBufferAttributes` as its output. And Final and PartialMerge aggregate rely on the
+      // output to bind references for `DeclarativeAggregate.mergeExpressions`. But if we copy the
+      // aggregate function somehow after aggregate planning, like `PlanSubqueries`, the
+      // `DeclarativeAggregate` will be replaced by a new instance with new
+      // `inputAggBufferAttributes` and `mergeExpressions`. Then Final and PartialMerge aggregate
+      // can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
+      // the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
+      // we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
+      val aggAttrs = aggregateExpressions.map(_.aggregateFunction)
+        .flatMap(_.inputAggBufferAttributes)
+      val distinctAttrs = child.output.filterNot(
+        a => (groupingAttributes ++ aggAttrs).exists(_.name == a.name))
+      // the order is consistent with `AggUtils.planAggregateWithOneDistinct`
+      groupingAttributes ++ distinctAttrs ++ aggAttrs
+    } else {
+      child.output
+    }
+  }
+
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
     // To individually generate code for each aggregate function, an element in `updateExprs` holds
     // all the expressions for the buffer of an aggregation function.
     val updateExprs = aggregateExpressions.map { e =>
@@ -848,9 +870,9 @@ case class HashAggregateExec(
   private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, bindReferences[Expression](groupingExpressions, child.output))
+      ctx, bindReferences[Expression](groupingExpressions, inputAttributes))
     val fastRowKeys = ctx.generateExpressions(
-      bindReferences[Expression](groupingExpressions, child.output))
+      bindReferences[Expression](groupingExpressions, inputAttributes))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash")
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
@@ -931,7 +953,7 @@ case class HashAggregateExec(
       }
     }
 
-    val inputAttr = aggregateBufferAttributes ++ child.output
+    val inputAttr = aggregateBufferAttributes ++ inputAttributes
     // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
     // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
     // generating input columns, we use `currentVars`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 67abca49c56f..1653d21e9942 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -974,12 +974,20 @@ class DataFrameAggregateSuite extends QueryTest
     }
   }
 
-  test("SPARK-31620: agg with subquery") {
-    withTempView("t1", "t2") {
-      sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
-      sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
-      checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
-        Row(4) :: Nil)
+  Seq(true, false).foreach { value =>
+    test(s"SPARK-31620: agg with subquery (codegen = $value)") {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
+        withTempView("t1", "t2") {
+          sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
+          sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
+          // test without grouping keys
+          checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
+            Row(4) :: Nil)
+          // test with grouping keys
+          checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " +
+            "t2 group by c"), Row(3, 4) :: Nil)
+        }
+      }
     }
   }
 }
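
Patch 05 abandons the tag and instead fixes the consumer: a Final/PartialMerge
`HashAggregateExec` now derives the attribute list it binds against from the
aggregate functions it currently holds, rather than trusting `child.output`. The
underlying brittleness is that references bind by expression ID, not by name. A
rough sketch of that resolution step with simplified stand-in types (this is not
Spark's `BindReferences`, just the idea):

    case class AttrRef(name: String, exprId: Long)

    // Resolve a reference to an ordinal in the input schema by exprId.
    def bindOrdinal(ref: AttrRef, input: Seq[AttrRef]): Int = {
      val ordinal = input.indexWhere(_.exprId == ref.exprId)
      if (ordinal == -1) {
        // Essentially the SPARK-31620 failure: the copied function's mergeExpressions
        // carry fresh exprIds that the partial aggregate's output (built from the
        // pre-copy function) does not contain.
        sys.error(s"Couldn't find ${ref.name}#${ref.exprId} in " +
          input.map(a => s"${a.name}#${a.exprId}").mkString("[", ",", "]"))
      }
      ordinal
    }

    // Same name, different exprId: binding fails.
    // bindOrdinal(AttrRef("sum", 42L), Seq(AttrRef("sum", 7L)))  // throws
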
From 9a0b7883176b3790cb6c24a70fada4e21b1cc6e5 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Wed, 13 May 2020 17:21:03 +0800
Subject: [PATCH 06/11] fix

---
 .../spark/sql/catalyst/expressions/aggregate/interfaces.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index aa29017aafe5..222ad6fab19e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.types._
 
 /** The mode of an [[AggregateFunction]]. */

From 3fd39c5721a2b2b13162b48dbee82081ffacb8ab Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Wed, 13 May 2020 18:14:35 +0800
Subject: [PATCH 07/11] address comment

---
 .../sql/execution/aggregate/HashAggregateExec.scala | 12 +++++-------
 .../apache/spark/sql/DataFrameAggregateSuite.scala  |  2 +-
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 91e1f27d9999..fd692d7aa076 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -342,12 +342,10 @@ case class HashAggregateExec(
       // can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
       // the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
      // we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
-      val aggAttrs = aggregateExpressions.map(_.aggregateFunction)
+      val aggAttrs = aggregateExpressions
+        .filter(a => a.mode == Final || !a.isDistinct).map(_.aggregateFunction)
         .flatMap(_.inputAggBufferAttributes)
-      val distinctAttrs = child.output.filterNot(
-        a => (groupingAttributes ++ aggAttrs).exists(_.name == a.name))
-      // the order is consistent with `AggUtils.planAggregateWithOneDistinct`
-      groupingAttributes ++ distinctAttrs ++ aggAttrs
+      child.output.dropRight(aggAttrs.length) ++ aggAttrs
     } else {
       child.output
     }
@@ -870,9 +868,9 @@ case class HashAggregateExec(
   private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, bindReferences[Expression](groupingExpressions, inputAttributes))
+      ctx, bindReferences[Expression](groupingExpressions, child.output))
     val fastRowKeys = ctx.generateExpressions(
-      bindReferences[Expression](groupingExpressions, inputAttributes))
+      bindReferences[Expression](groupingExpressions, child.output))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash")
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 1653d21e9942..66f08523479b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -975,7 +975,7 @@ class DataFrameAggregateSuite extends QueryTest
   }
 
   Seq(true, false).foreach { value =>
-    test(s"SPARK-31620: agg with subquery (codegen = $value)") {
+    test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
       withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
         withTempView("t1", "t2") {
           sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")

From 64d13fc32d4bc6d460a84f100d3921c562baa7d8 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Thu, 14 May 2020 11:02:33 +0800
Subject: [PATCH 08/11] upate

---
 .../spark/sql/execution/aggregate/HashAggregateExec.scala | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index fd692d7aa076..0b9c2ae66509 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -343,7 +343,10 @@ case class HashAggregateExec(
       // the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
       // we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
       val aggAttrs = aggregateExpressions
-        .filter(a => a.mode == Final || !a.isDistinct).map(_.aggregateFunction)
+        // there're exactly four cases needs `inputAggBufferAttributes` from child according to the
+        // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
+        // Partial -> PartialMerge, PartialMerge -> PartialMerge.
+        .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
         .flatMap(_.inputAggBufferAttributes)
       child.output.dropRight(aggAttrs.length) ++ aggAttrs
     } else {

From 89ae4bf621fd249f73305b30e7a8d158731d5178 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Thu, 14 May 2020 11:29:05 +0800
Subject: [PATCH 09/11] improve test

---
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 66f08523479b..4ae6baf794bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -980,12 +980,22 @@ class DataFrameAggregateSuite extends QueryTest
         withTempView("t1", "t2") {
           sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
           sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
+
           // test without grouping keys
           checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
             Row(4) :: Nil)
+
           // test with grouping keys
           checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " +
             "t2 group by c"), Row(3, 4) :: Nil)
+
+          // test with distinct
+          checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil)
+
+          // test subquery with agg
+          checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
         }
       }
     }
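
Patch 10 below moves the `inputAttributes` logic from `HashAggregateExec` into the
shared `BaseAggregateExec` trait so that sort-based and object-hash aggregation get
the same repair. The core step is a small splice: by the planner's construction, a
partial aggregate's output ends with the agg buffer attributes of the merged
functions, so the merging side keeps the leading attributes (grouping keys, distinct
columns) and substitutes the buffer attributes of the functions it currently holds.
A minimal sketch, with illustrative attribute names of the form name#exprId:

    // Assumed layout (per the code comment in the patch): child output ends with the
    // inputAggBufferAttributes of the merged aggregate functions, in order.
    def rebasedInput(childOutput: Seq[String], currentBufferAttrs: Seq[String]): Seq[String] =
      childOutput.dropRight(currentBufferAttrs.length) ++ currentBufferAttrs

    // e.g. after a copy replaced sum#2 with sum#9:
    assert(rebasedInput(Seq("c#1", "sum#2"), Seq("sum#9")) == Seq("c#1", "sum#9"))
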
From 8aecd576dde246f71ca9add2ed65c8c0592c0674 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Thu, 14 May 2020 16:49:40 +0800
Subject: [PATCH 10/11] fix sort/object agg as well

---
 .../aggregate/BaseAggregateExec.scala         | 26 ++++++++++++++++++-
 .../aggregate/HashAggregateExec.scala         | 23 ----------------
 .../aggregate/ObjectHashAggregateExec.scala   |  2 +-
 .../aggregate/SortAggregateExec.scala         |  2 +-
 .../spark/sql/DataFrameAggregateSuite.scala   |  8 ++++++
 5 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
index f506bdddc16b..f1e053f7fb2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
 import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}
 
 /**
@@ -40,4 +40,28 @@ trait BaseAggregateExec extends UnaryExecNode {
        |${ExplainUtils.generateFieldString("Results", resultExpressions)}
        |""".stripMargin
   }
+
+  protected def inputAttributes: Seq[Attribute] = {
+    val modes = aggregateExpressions.map(_.mode).distinct
+    if (modes.contains(Final) || modes.contains(PartialMerge)) {
+      // SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's
+      // `inputAggBufferAttributes` as its output. And Final and PartialMerge aggregate rely on the
+      // output to bind references for `DeclarativeAggregate.mergeExpressions`. But if we copy the
+      // aggregate function somehow after aggregate planning, like `PlanSubqueries`, the
+      // `DeclarativeAggregate` will be replaced by a new instance with new
+      // `inputAggBufferAttributes` and `mergeExpressions`. Then Final and PartialMerge aggregate
+      // can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
+      // the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
+      // we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
+      val aggAttrs = aggregateExpressions
+        // there're exactly four cases needs `inputAggBufferAttributes` from child according to the
+        // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
+        // Partial -> PartialMerge, PartialMerge -> PartialMerge.
+        .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
+        .flatMap(_.inputAggBufferAttributes)
+      child.output.dropRight(aggAttrs.length) ++ aggAttrs
+    } else {
+      child.output
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 0b9c2ae66509..9c07ea10a87e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -331,29 +331,6 @@ case class HashAggregateExec(
     }
   }
 
-  private def inputAttributes: Seq[Attribute] = {
-    if (modes.contains(Final) || modes.contains(PartialMerge)) {
-      // SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's
-      // `inputAggBufferAttributes` as its output. And Final and PartialMerge aggregate rely on the
-      // output to bind references for `DeclarativeAggregate.mergeExpressions`. But if we copy the
-      // aggregate function somehow after aggregate planning, like `PlanSubqueries`, the
-      // `DeclarativeAggregate` will be replaced by a new instance with new
-      // `inputAggBufferAttributes` and `mergeExpressions`. Then Final and PartialMerge aggregate
-      // can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
-      // the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
-      // we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
-      val aggAttrs = aggregateExpressions
-        // there're exactly four cases needs `inputAggBufferAttributes` from child according to the
-        // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
-        // Partial -> PartialMerge, PartialMerge -> PartialMerge.
-        .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
-        .flatMap(_.inputAggBufferAttributes)
-      child.output.dropRight(aggAttrs.length) ++ aggAttrs
-    } else {
-      child.output
-    }
-  }
-
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 3fb58eb2cc8b..f1c0719ff894 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -123,7 +123,7 @@ case class ObjectHashAggregateExec(
           resultExpressions,
           (expressions, inputSchema) =>
             MutableProjection.create(expressions, inputSchema),
-          child.output,
+          inputAttributes,
           iter,
           fallbackCountThreshold,
           numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 9610eab82c7c..ba0c3517a1a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -88,7 +88,7 @@ case class SortAggregateExec(
         val outputIter = new SortBasedAggregationIterator(
           partIndex,
           groupingExpressions,
-          child.output,
+          inputAttributes,
           iter,
           aggregateExpressions,
           aggregateAttributes,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 4ae6baf794bd..28e214a6cbf0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -996,6 +996,14 @@ class DataFrameAggregateSuite extends QueryTest
           // test subquery with agg
           checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," +
             " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
+
+          // test SortAggregateExec
+          checkAnswer(sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2"),
+            Row("str1") :: Nil)
+
+          // test ObjectHashAggregateExec
+          checkAnswer(sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum" +
+            " from t2"), Row(Array(4), 4) :: Nil)
         }
       }
     }

From 493157a3b97616d221ec2b5ddf1a21cdf9a1a3f4 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Thu, 14 May 2020 21:41:37 +0800
Subject: [PATCH 11/11] check operator

---
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 28e214a6cbf0..2293d4ae61af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -998,12 +998,16 @@ class DataFrameAggregateSuite extends QueryTest
             " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
 
           // test SortAggregateExec
-          checkAnswer(sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2"),
-            Row("str1") :: Nil)
+          var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find { case _: SortAggregateExec => true }.isDefined)
+          checkAnswer(df, Row("str1") :: Nil)
 
           // test ObjectHashAggregateExec
-          checkAnswer(sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum" +
-            " from t2"), Row(Array(4), 4) :: Nil)
+          df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find { case _: ObjectHashAggregateExec => true }.isDefined)
+          checkAnswer(df, Row(Array(4), 4) :: Nil)
         }
       }
     }
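
Patch 11 pins each query to the physical operator it is meant to exercise by
scanning the executed plan. The same check can be done interactively; a hedged
usage sketch for spark-shell on a Spark 3.0-era build, assuming the t1/t2 temp
views from the tests are registered (the sketch adds an explicit `case _ => false`
because `TreeNode.find` expects a total predicate; a pattern-match literal without
a default case throws `MatchError` on the first non-matching node):

    import org.apache.spark.sql.execution.aggregate.SortAggregateExec

    val df = spark.sql(
      "select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")

    // executedPlan is a TreeNode; find returns the first node matching the predicate.
    val usesSortAgg = df.queryExecution.executedPlan.find {
      case _: SortAggregateExec => true
      case _ => false
    }.isDefined
    assert(usesSortAgg)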