Skip to content
Browse files

[SPARK-22951][SQL][BRANCH-2.2] fix aggregation after dropDuplicates o…

…n empty dataframes

## What changes were proposed in this pull request?

(courtesy of liancheng)

Spark SQL supports both global aggregation and grouping aggregation. Global aggregation always return a single row with the initial aggregation state as the output, even there are zero input rows. Spark implements this by simply checking the number of grouping keys and treats an aggregation as a global aggregation if it has zero grouping keys.

However, this simple principle drops the ball in the following case:

spark.emptyDataFrame.dropDuplicates().agg(count($"*") as "c").show()
// +---+
// | c |
// +---+
// | 1 |
// +---+

The reason is that:

1. `df.dropDuplicates()` is roughly translated into something equivalent to:

val allColumns = { col }
df.groupBy(allColumns: _*).agg(allColumns.head, allColumns.tail: _*)

This translation is implemented in the rule `ReplaceDeduplicateWithAggregate`.

2. `spark.emptyDataFrame` contains zero columns and zero rows.

Therefore, rule `ReplaceDeduplicateWithAggregate` makes a confusing transformation roughly equivalent to the following one:

=> spark.emptyDataFrame.groupBy().agg(Map.empty[String, String])

The above transformation is confusing because the resulting aggregate operator contains no grouping keys (because `emptyDataFrame` contains no columns), and gets recognized as a global aggregation. As a result, Spark SQL allocates a single row filled by the initial aggregation state and uses it as the output, and returns a wrong result.

To fix this issue, this PR tweaks `ReplaceDeduplicateWithAggregate` by appending a literal `1` to the grouping key list of the resulting `Aggregate` operator when the input plan contains zero output columns. In this way, `spark.emptyDataFrame.dropDuplicates()` is now translated into a grouping aggregation, roughly depicted as:

=> spark.emptyDataFrame.groupBy(lit(1)).agg(Map.empty[String, String])

Which is now properly treated as a grouping aggregation and returns the correct answer.

## How was this patch tested?

New unit tests added

Closes #23434 from dongjoon-hyun/SPARK-22951-2.

Authored-by: Feng Liu <>
Signed-off-by: Dongjoon Hyun <>
  • Loading branch information...
Feng Liu authored and dongjoon-hyun committed Jan 3, 2019
1 parent 6f435e9 commit f0c6f1da39550ac0cae274348dc885eba861c383
@@ -1165,7 +1165,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
Alias(new First(attr).toAggregateExpression(),
Aggregate(keys, aggCols, child)
// SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
// aggregations by checking the number of grouping keys. The key difference here is that a
// global aggregation always returns at least one row even if there are no input rows. Here
// we append a literal when the grouping key list is empty so that the result aggregate
// operator is properly treated as a grouping aggregation.
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
Aggregate(nonemptyKeys, aggCols, child)

@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -94,6 +94,14 @@ class ReplaceOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)

test("add one grouping key if necessary when replace Deduplicate with Aggregate") {
val input = LocalRelation()
val query = Deduplicate(Seq.empty, input, streaming = false) // dropDuplicates()
val optimized = Optimize.execute(query.analyze)
val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input)
comparePlans(optimized, correctAnswer)

test("don't replace streaming Deduplicate") {
val input = LocalRelation(', '
val attrA = input.output(0)
@@ -24,7 +24,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData.DecimalData
import org.apache.spark.sql.types.{Decimal, DecimalType}
import org.apache.spark.sql.types.DecimalType

case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)

@@ -453,7 +453,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {

test("null moments") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")

emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
Row(null, null, null, null, null))
@@ -608,4 +607,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
assert(exchangePlans.length == 1)

Seq(true, false).foreach { codegen =>
test("SPARK-22951: dropDuplicates on empty dataFrames should produce correct aggregate " +
s"results when codegen is enabled: $codegen") {
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) {
// explicit global aggregations
val emptyAgg = Map.empty[String, String]
checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row()))
checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0)))
checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row()))
checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row()))
checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0)))

// global aggregation is converted to grouping aggregation:
assert(spark.emptyDataFrame.dropDuplicates().count() == 0)

0 comments on commit f0c6f1d

Please sign in to comment.
You can’t perform that action at this time.