Skip to content

Commit

Permalink
[SPARK-29708][SQL][2.4] Correct aggregated values when grouping sets …
Browse files Browse the repository at this point in the history
…are duplicated

### What changes were proposed in this pull request?

This pr intends to fix wrong aggregated values in `GROUPING SETS` when there are duplicated grouping sets in a query (e.g., `GROUPING SETS ((k1),(k1))`).

For example;
```
scala> spark.table("t").show()
+---+---+---+
| k1| k2|  v|
+---+---+---+
|  0|  0|  3|
+---+---+---+

scala> sql("""select grouping_id(), k1, k2, sum(v) from t group by grouping sets ((k1),(k1,k2),(k2,k1),(k1,k2))""").show()
+-------------+---+----+------+
|grouping_id()| k1|  k2|sum(v)|
+-------------+---+----+------+
|            0|  0|   0|     9| <---- wrong aggregate value and the correct answer is `3`
|            1|  0|null|     3|
+-------------+---+----+------+

// PostgreSQL case
postgres=#  select k1, k2, sum(v) from t group by grouping sets ((k1),(k1,k2),(k2,k1),(k1,k2));
 k1 |  k2  | sum
----+------+-----
  0 |    0 |   3
  0 |    0 |   3
  0 |    0 |   3
  0 | NULL |   3
(4 rows)

// Hive case
hive> select GROUPING__ID, k1, k2, sum(v) from t group by k1, k2 grouping sets ((k1),(k1,k2),(k2,k1),(k1,k2));
1	0	NULL	3
0	0	0	3
```
[MS SQL Server has the same behaviour as PostgreSQL](#26961 (comment)). This pr follows the behaviour of PostgreSQL/SQL Server; it adds one more virtual attribute in `Expand` to avoid wrongly grouping rows that share the same grouping ID.

This is a backport of #26961 to `branch-2.4`.

### Why are the changes needed?

To fix bugs.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

The existing tests.

Closes #27229 from maropu/SPARK-29708-BRANCHC2.4.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
maropu authored and dongjoon-hyun committed Jan 16, 2020
1 parent d6261a1 commit 60a908e
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 6 deletions.
Expand Up @@ -670,11 +670,14 @@ object Expand {
child: LogicalPlan): Expand = {
val attrMap = groupByAttrs.zipWithIndex.toMap

val hasDuplicateGroupingSets = groupingSetsAttrs.size !=
groupingSetsAttrs.map(_.map(_.exprId).toSet).distinct.size

// Create an array of Projections for the child projection, and replace the projections'
// expressions which equal GroupBy expressions with Literal(null), if those expressions
// are not set for this grouping set.
val projections = groupingSetsAttrs.map { groupingSetAttrs =>
child.output ++ groupByAttrs.map { attr =>
val projections = groupingSetsAttrs.zipWithIndex.map { case (groupingSetAttrs, i) =>
val projAttrs = child.output ++ groupByAttrs.map { attr =>
if (!groupingSetAttrs.contains(attr)) {
// If the input attribute is not contained in this grouping set,
// replace it with a constant null.
Expand All @@ -684,11 +687,25 @@ object Expand {
}
// groupingId is the last output, here we use the bit mask as the concrete value for it.
} :+ Literal.create(buildBitmask(groupingSetAttrs, attrMap), IntegerType)

if (hasDuplicateGroupingSets) {
// If `groupingSetsAttrs` has duplicate entries (e.g., GROUPING SETS ((key), (key))),
// we add one more virtual grouping attribute (`_gen_grouping_pos`) to avoid
// wrongly grouping rows with the same grouping ID.
projAttrs :+ Literal.create(i, IntegerType)
} else {
projAttrs
}
}

// the `groupByAttrs` has different meaning in `Expand.output`, it could be the original
// grouping expression or null, so here we create new instance of it.
val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid
val output = if (hasDuplicateGroupingSets) {
val gpos = AttributeReference("_gen_grouping_pos", IntegerType, false)()
child.output ++ groupByAttrs.map(_.newInstance) :+ gid :+ gpos
} else {
child.output ++ groupByAttrs.map(_.newInstance) :+ gid
}
Expand(projections, output, Project(child.output ++ groupByAliases, child))
}
}
Expand Down
6 changes: 6 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql
Expand Up @@ -51,3 +51,9 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE;

SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (());

-- duplicate entries in grouping sets
SELECT k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1));

SELECT grouping__id, k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1));

SELECT grouping(k1), k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1));
47 changes: 44 additions & 3 deletions sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 15
-- Number of queries: 18


-- !query 0
Expand Down Expand Up @@ -110,8 +110,10 @@ SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUP
-- !query 10 schema
struct<(a + b):int,b:int,sum(c):bigint>
-- !query 10 output
2 NULL 2
4 NULL 4
2 NULL 1
2 NULL 1
4 NULL 2
4 NULL 2
NULL 1 1
NULL 2 2

Expand Down Expand Up @@ -164,3 +166,42 @@ struct<>
-- !query 14 output
org.apache.spark.sql.AnalysisException
expression '`c1`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;


-- !query 15
SELECT k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1))
-- !query 15 schema
struct<k1:int,k2:int,avg(v):double>
-- !query 15 output
1 1 1.0
1 1 1.0
1 NULL 1.0
2 2 2.0
2 2 2.0
2 NULL 2.0


-- !query 16
SELECT grouping__id, k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1))
-- !query 16 schema
struct<grouping__id:int,k1:int,k2:int,avg(v):double>
-- !query 16 output
0 1 1 1.0
0 1 1 1.0
0 2 2 2.0
0 2 2 2.0
1 1 NULL 1.0
1 2 NULL 2.0


-- !query 17
SELECT grouping(k1), k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1))
-- !query 17 schema
struct<grouping(k1):tinyint,k1:int,k2:int,avg(v):double>
-- !query 17 output
0 1 1 1.0
0 1 1 1.0
0 1 NULL 1.0
0 2 2 2.0
0 2 2 2.0
0 2 NULL 2.0

0 comments on commit 60a908e

Please sign in to comment.