[SPARK-31620][SQL] Fix reference binding failure when a final agg contains a subquery

### What changes were proposed in this pull request?

Instead of using `child.output` directly, we should use the `inputAggBufferAttributes` from the current agg expressions for `Final` and `PartialMerge` aggregates to bind references for their `mergeExpressions`.

### Why are the changes needed?

When planning aggregates, the partial aggregate uses the agg funcs' `inputAggBufferAttributes` as its output, see https://github.com/apache/spark/blob/v3.0.0-rc1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala#L105

For final `HashAggregateExec`, we need to bind the `DeclarativeAggregate.mergeExpressions` with the output of the partial aggregate operator, see https://github.com/apache/spark/blob/v3.0.0-rc1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L348

This is usually fine. However, if we copy the agg func somehow after agg planning, e.g. in `PlanSubqueries`, the `DeclarativeAggregate` will be replaced by a new instance with new `inputAggBufferAttributes` and `mergeExpressions`. Then we can't bind the `mergeExpressions` to the output of the partial aggregate operator, as they use the `inputAggBufferAttributes` of the original `DeclarativeAggregate` from before the copy.

Note that, `ImperativeAggregate` doesn't have this problem, as we don't need to bind its `mergeExpressions`. It has a different mechanism to access buffer values, via `mutableAggBufferOffset` and `inputAggBufferOffset`.
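
To make the failure concrete, here is a minimal, self-contained Scala sketch of the mechanism. It does not use Spark's real classes: `Attr`, `Ids`, `DeclAgg`, and `bind` are hypothetical stand-ins for `AttributeReference` exprId matching, a `DeclarativeAggregate`, and reference binding.

```scala
// Minimal model: attributes are matched by exprId, like Spark's AttributeReference.
final case class Attr(name: String, exprId: Long)

object Ids {
  private var i = 0L
  def next(): Long = { i += 1; i }
}

// Stand-in for a DeclarativeAggregate: every new instance mints fresh buffer
// attributes, mirroring what happens when the agg func is copied.
final class DeclAgg {
  val inputAggBufferAttributes: Seq[Attr] = Seq(Attr("sum", Ids.next()))
  def mergeExpressions: Seq[Attr] = inputAggBufferAttributes
}

object Demo extends App {
  val planned = new DeclAgg                             // instance used at agg planning time
  val partialOutput = planned.inputAggBufferAttributes  // what the partial aggregate outputs

  // Binding resolves an expression to an input slot by exprId.
  def bind(e: Attr, input: Seq[Attr]): Option[Int] =
    Some(input.indexWhere(_.exprId == e.exprId)).filter(_ >= 0)

  println(bind(planned.mergeExpressions.head, partialOutput)) // Some(0): binds fine

  val copied = new DeclAgg                                    // e.g. re-created by PlanSubqueries
  println(bind(copied.mergeExpressions.head, partialOutput))  // None: binding failure
}
```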

### Does this PR introduce _any_ user-facing change?

Yes. Previously such queries failed with a reference binding error; after this change they run successfully.
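
For illustration, using the tables from the regression test below (assuming a `spark` session, e.g. in `spark-shell`), a query of this shape is affected:

```scala
spark.sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
spark.sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")

// The aggregate expression contains a scalar subquery, so `PlanSubqueries`
// re-creates the agg func after aggregate planning.
spark.sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2").show()
// Before this fix: reference binding fails while planning the final aggregate
// (the exact error text may vary by version). After the fix: csum = 4.
```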

### How was this patch tested?

Added a regression test.

Closes #28496 from Ngone51/spark-31620.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Ngone51 authored and cloud-fan committed May 15, 2020
1 parent 194ac3b commit d8b001f
Showing 5 changed files with 69 additions and 6 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
 import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}
 
 /**
@@ -40,4 +40,28 @@ trait BaseAggregateExec extends UnaryExecNode {
        |${ExplainUtils.generateFieldString("Results", resultExpressions)}
        |""".stripMargin
   }
+
+  protected def inputAttributes: Seq[Attribute] = {
+    val modes = aggregateExpressions.map(_.mode).distinct
+    if (modes.contains(Final) || modes.contains(PartialMerge)) {
+      // SPARK-31620: when planning aggregates, the partial aggregate uses the aggregate
+      // function's `inputAggBufferAttributes` as its output, and Final and PartialMerge
+      // aggregates rely on that output to bind references for
+      // `DeclarativeAggregate.mergeExpressions`. But if we copy the aggregate function
+      // somehow after aggregate planning, e.g. in `PlanSubqueries`, the
+      // `DeclarativeAggregate` will be replaced by a new instance with new
+      // `inputAggBufferAttributes` and `mergeExpressions`. Then Final and PartialMerge
+      // aggregates can't bind the `mergeExpressions` to the output of the partial
+      // aggregate, as they use the `inputAggBufferAttributes` of the original
+      // `DeclarativeAggregate` from before the copy. Instead, we must use the
+      // `inputAggBufferAttributes` from after the copy to match the new `mergeExpressions`.
+      val aggAttrs = aggregateExpressions
+        // There are exactly four cases that need `inputAggBufferAttributes` from the child,
+        // according to the agg planning in `AggUtils`: Partial -> Final,
+        // PartialMerge -> Final, Partial -> PartialMerge, PartialMerge -> PartialMerge.
+        .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
+        .flatMap(_.inputAggBufferAttributes)
+      child.output.dropRight(aggAttrs.length) ++ aggAttrs
+    } else {
+      child.output
+    }
+  }
 }
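
A note on the `dropRight` splice above: per the `AggUtils` link in the description, the partial aggregate's output is the grouping attributes followed by the agg funcs' `inputAggBufferAttributes`, so replacing the tail of `child.output` restores alignment by position. A toy sketch, with plain strings standing in for attributes (hypothetical names):

```scala
// Partial aggregate output = grouping attrs ++ old buffer attrs (stale exprIds).
val childOutput = Seq("g1", "g2", "sum_buf#1", "count_buf#2")
// Fresh buffer attrs minted by the copied agg funcs (new exprIds).
val aggAttrs = Seq("sum_buf#9", "count_buf#10")
// Same splice as `inputAttributes` above: keep grouping attrs, swap in fresh buffers.
val inputAttributes = childOutput.dropRight(aggAttrs.length) ++ aggAttrs
// => Seq("g1", "g2", "sum_buf#9", "count_buf#10")
```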

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

@@ -129,7 +129,7 @@ case class HashAggregateExec(
             resultExpressions,
             (expressions, inputSchema) =>
               MutableProjection.create(expressions, inputSchema),
-            child.output,
+            inputAttributes,
             iter,
             testFallbackStartsAt,
             numOutputRows,
@@ -334,7 +334,7 @@ case class HashAggregateExec(
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
     // To individually generate code for each aggregate function, an element in `updateExprs` holds
     // all the expressions for the buffer of an aggregation function.
     val updateExprs = aggregateExpressions.map { e =>
@@ -931,7 +931,7 @@ case class HashAggregateExec(
       }
     }
 
-    val inputAttr = aggregateBufferAttributes ++ child.output
+    val inputAttr = aggregateBufferAttributes ++ inputAttributes
     // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
     // generating code for buffer columns, we use `INPUT_ROW` (will be the buffer row), while
     // generating input columns, we use `currentVars`.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala

@@ -123,7 +123,7 @@ case class ObjectHashAggregateExec(
             resultExpressions,
             (expressions, inputSchema) =>
               MutableProjection.create(expressions, inputSchema),
-            child.output,
+            inputAttributes,
             iter,
             fallbackCountThreshold,
             numOutputRows)

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

@@ -88,7 +88,7 @@ case class SortAggregateExec(
         val outputIter = new SortBasedAggregationIterator(
           partIndex,
           groupingExpressions,
-          child.output,
+          inputAttributes,
           iter,
           aggregateExpressions,
           aggregateAttributes,

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

@@ -973,4 +973,43 @@ class DataFrameAggregateSuite extends QueryTest
       assert(error.message.contains("function count_if requires boolean type"))
     }
   }
+
+  Seq(true, false).foreach { value =>
+    test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
+        withTempView("t1", "t2") {
+          sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
+          sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
+
+          // test without grouping keys
+          checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
+            Row(4) :: Nil)
+
+          // test with grouping keys
+          checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " +
+            "t2 group by c"), Row(3, 4) :: Nil)
+
+          // test with distinct
+          checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil)
+
+          // test subquery with agg
+          checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
+
+          // test SortAggregateExec
+          var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find { case _: SortAggregateExec => true; case _ => false }.isDefined)
+          checkAnswer(df, Row("str1") :: Nil)
+
+          // test ObjectHashAggregateExec
+          df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find { case _: ObjectHashAggregateExec => true; case _ => false }.isDefined)
+          checkAnswer(df, Row(Array(4), 4) :: Nil)
+        }
+      }
+    }
+  }
 }
