[SPARK-33308][SQL] Refactor current grouping analytics
### What changes were proposed in this pull request?
As discussed in
#30145 (comment)
#30145 (comment)

We need to rewrite the current grouping analytics grammar so that it is as flexible as PostgreSQL's, which will support subsequent development.
PostgreSQL supports the following:
```
select a, b, c, count(1) from t group by cube (a, b, c);
select a, b, c, count(1) from t group by cube(a, b, c);
select a, b, c, count(1) from t group by cube (a, b, c, (a, b), (a, b, c));
select a, b, c, count(1) from t group by rollup(a, b, c);
select a, b, c, count(1) from t group by rollup (a, b, c);
select a, b, c, count(1) from t group by rollup (a, b, c, (a, b), (a, b, c));
```
This PR does the three things listed below, which will be split into separate PRs:

 - Refactor CUBE/ROLLUP (regarding them as ANTLR tokens in a parser)
 - Refactor GROUPING SETS (the logical node -> a new expr)
 - Support new syntax for CUBE/ROLLUP (e.g., GROUP BY CUBE ((a, b), (a, c)))

### Why are the changes needed?
The current grouping analytics grammar needs to be rewritten so that it is as flexible as PostgreSQL's, which will support subsequent development.

### Does this PR introduce _any_ user-facing change?
Users can now write grouping analytics clauses with a grammar as flexible as PostgreSQL's.

### How was this patch tested?
Added unit tests.

Closes #30212 from AngersZhuuuu/refact-grouping-analytics.

Lead-authored-by: angerszhu <angers.zhu@gmail.com>
Co-authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Co-authored-by: AngersZhuuuu <angers.zhu@gmail.com>
Co-authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
2 people authored and cloud-fan committed Mar 30, 2021
1 parent 935aa8c commit a98dc60
Showing 15 changed files with 435 additions and 257 deletions.
44 changes: 27 additions & 17 deletions docs/sql-ref-syntax-qry-select-groupby.md
@@ -32,7 +32,7 @@ When a FILTER clause is attached to an aggregate function, only the matching row
GROUP BY group_expression [ , group_expression [ , ... ] ]
[ { WITH ROLLUP | WITH CUBE | GROUPING SETS (grouping_set [ , ...]) } ]

GROUP BY GROUPING SETS (grouping_set [ , ...])
GROUP BY { group_expression | { ROLLUP | CUBE | GROUPING SETS } (grouping_set [ , ...]) } [ , ... ]
```

While aggregate functions are defined as
@@ -42,40 +42,50 @@ aggregate_name ( [ DISTINCT ] expression [ , ... ] ) [ FILTER ( WHERE boolean_ex

### Parameters

* **GROUPING SETS**
* **grouping_expression**

Groups the rows for each subset of the expressions specified in the grouping sets. For example,
`GROUP BY GROUPING SETS (warehouse, product)` is semantically equivalent
to union of results of `GROUP BY warehouse` and `GROUP BY product`. This clause
is a shorthand for a `UNION ALL` where each leg of the `UNION ALL`
operator performs aggregation of subset of the columns specified in the `GROUPING SETS` clause.
Specifies the criteria based on which the rows are grouped together. The grouping of rows is performed based on
result values of the grouping expressions. A grouping expression may be a column name like `GROUP BY a`, a column position like
`GROUP BY 0`, or an expression like `GROUP BY a + b`.

* **grouping_set**

A grouping set is specified by zero or more comma-separated expressions in parentheses.
A grouping set is specified by zero or more comma-separated expressions in parentheses. When the
grouping set has only one element, parentheses can be omitted. For example, `GROUPING SETS ((a), (b))`
is the same as `GROUPING SETS (a, b)`.

**Syntax:** `( [ expression [ , ... ] ] )`
**Syntax:** `{ ( [ expression [ , ... ] ] ) | expression }`

* **grouping_expression**
* **GROUPING SETS**

Specifies the criteria based on which the rows are grouped together. The grouping of rows is performed based on
result values of the grouping expressions. A grouping expression may be a column alias, a column position
or an expression.
Groups the rows for each grouping set specified after GROUPING SETS. For example,
`GROUP BY GROUPING SETS ((warehouse), (product))` is semantically equivalent
to union of results of `GROUP BY warehouse` and `GROUP BY product`. This clause
is a shorthand for a `UNION ALL` where each leg of the `UNION ALL`
operator performs aggregation of each grouping set specified in the `GROUPING SETS` clause.
Similarly, `GROUP BY GROUPING SETS ((warehouse, product), (product), ())` is semantically
equivalent to the union of results of `GROUP BY warehouse, product`, `GROUP BY product`
and global aggregate.

* **ROLLUP**

Specifies multiple levels of aggregations in a single statement. This clause is used to compute aggregations
based on multiple grouping sets. `ROLLUP` is a shorthand for `GROUPING SETS`. For example,
`GROUP BY warehouse, product WITH ROLLUP` is equivalent to `GROUP BY GROUPING SETS
((warehouse, product), (warehouse), ())`.
`GROUP BY warehouse, product WITH ROLLUP` or `GROUP BY ROLLUP(warehouse, product)` is equivalent to
`GROUP BY GROUPING SETS((warehouse, product), (warehouse), ())`.
`GROUP BY ROLLUP(warehouse, product, (warehouse, location))` is equivalent to
`GROUP BY GROUPING SETS((warehouse, product, location), (warehouse, product), (warehouse), ())`.
The N elements of a `ROLLUP` specification results in N+1 `GROUPING SETS`.

* **CUBE**

`CUBE` clause is used to perform aggregations based on combination of grouping columns specified in the
`GROUP BY` clause. `CUBE` is a shorthand for `GROUPING SETS`. For example,
`GROUP BY warehouse, product WITH CUBE` is equivalent to `GROUP BY GROUPING SETS
((warehouse, product), (warehouse), (product), ())`.
`GROUP BY warehouse, product WITH CUBE` or `GROUP BY CUBE(warehouse, product)` is equivalent to
`GROUP BY GROUPING SETS((warehouse, product), (warehouse), (product), ())`.
`GROUP BY CUBE(warehouse, product, (warehouse, location))` is equivalent to
`GROUP BY GROUPING SETS((warehouse, product, location), (warehouse, product), (warehouse, location),
(product, warehouse, location), (warehouse), (product), (warehouse, product), ())`.
The N elements of a `CUBE` specification results in 2^N `GROUPING SETS`.

* **aggregate_name**
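
The expansion rules described in the documentation change above can be sketched in a few lines of Scala. This is a minimal illustration, not Spark's implementation: columns are modeled as plain strings, the names `GroupingSetExpansion`, `rollup`, and `cube` are hypothetical, and dropping repeated columns with `distinct` is an assumption modeled on the `ROLLUP(warehouse, product, (warehouse, location))` example.

```scala
object GroupingSetExpansion {
  // Each element of a ROLLUP/CUBE specification is itself a group of columns,
  // e.g. ROLLUP(warehouse, product, (warehouse, location)) has three elements.
  type Element = Seq[String]
  type GroupingSet = Seq[String]

  // ROLLUP keeps every prefix of the element list: N elements give N + 1 sets.
  // Assumption: columns repeated across elements are kept only once.
  def rollup(elements: Seq[Element]): Seq[GroupingSet] =
    elements.inits.map(_.flatten.distinct).toIndexedSeq

  // CUBE keeps every subset of the element list: N elements give 2^N sets.
  def cube(elements: Seq[Element]): Seq[GroupingSet] =
    elements.foldRight(Seq(Seq.empty[Element])) { (e, acc) =>
      acc.map(e +: _) ++ acc
    }.map(_.flatten.distinct)
}
```

For instance, `GroupingSetExpansion.rollup(Seq(Seq("warehouse"), Seq("product"), Seq("warehouse", "location")))` yields the four grouping sets listed for the `ROLLUP` example, and `cube` over the same three elements yields 2^3 = 8 sets.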
@@ -588,11 +588,21 @@ fromClause
;

aggregationClause
: GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
: GROUP BY groupingExpressionsWithGroupingAnalytics+=groupByClause
(',' groupingExpressionsWithGroupingAnalytics+=groupByClause)*
| GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
WITH kind=ROLLUP
| WITH kind=CUBE
| kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')?
| GROUP BY kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')'
;

groupByClause
: groupingAnalytics
| expression
;

groupingAnalytics
: (ROLLUP | CUBE | GROUPING SETS) '(' groupingSet (',' groupingSet)* ')'
;

groupingSet
@@ -431,9 +431,6 @@ class Analyzer(override val catalogManager: CatalogManager)
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)

case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.copy(aggregations = assignAliases(g.aggregations))

case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child)
if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) =>
Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child)
@@ -444,40 +441,6 @@ class Analyzer(override val catalogManager: CatalogManager)
}

object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
/*
* GROUP BY a, b, c WITH ROLLUP
* is equivalent to
* GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( ) ).
* Group Count: N + 1 (N is the number of group expressions)
*
* We need to get all of its subsets for the rule described above, the subset is
* represented as sequence of expressions.
*/
def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.inits.toIndexedSeq

/*
* GROUP BY a, b, c WITH CUBE
* is equivalent to
* GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ).
* Group Count: 2 ^ N (N is the number of group expressions)
*
* We need to get all of its subsets for a given GROUPBY expression, the subsets are
* represented as sequence of expressions.
*/
def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = {
// `cubeExprs0` is recursive and returns a lazy Stream. Here we call `toIndexedSeq` to
// materialize it and avoid serialization problems later on.
cubeExprs0(exprs).toIndexedSeq
}

def cubeExprs0(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match {
case x :: xs =>
val initial = cubeExprs0(xs)
initial.map(x +: _) ++ initial
case Nil =>
Seq(Seq.empty)
}

private[analysis] def hasGroupingFunction(e: Expression): Boolean = {
e.collectFirst {
case g: Grouping => g
@@ -657,14 +620,9 @@ class Analyzer(override val catalogManager: CatalogManager)
val aggForResolving = h.child match {
// For CUBE/ROLLUP expressions, to avoid resolving repeatedly, here we delete them from
// groupingExpressions for condition resolving.
case a @ Aggregate(Seq(c @ Cube(groupByExprs)), _, _) =>
a.copy(groupingExpressions = groupByExprs)
case a @ Aggregate(Seq(r @ Rollup(groupByExprs)), _, _) =>
a.copy(groupingExpressions = groupByExprs)
case g: GroupingSets =>
Aggregate(
getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs),
g.aggregations, g.child)
case a @ Aggregate(Seq(gs: GroupingSet), _, _) =>
a.copy(groupingExpressions =
getFinalGroupByExpressions(gs.groupingSets, gs.groupByExprs))
}
// Try resolving the condition of the filter as though it is in the aggregate clause
val resolvedInfo =
@@ -674,15 +632,10 @@ class Analyzer(override val catalogManager: CatalogManager)
if (resolvedInfo.nonEmpty) {
val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
val newChild = h.child match {
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case x: GroupingSets =>
case Aggregate(Seq(gs: GroupingSet), aggregateExpressions, child) =>
constructAggregate(
x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
gs.selectedGroupByExprs, gs.groupByExprs,
aggregateExpressions ++ extraAggExprs, child)
}

// Since the exprId of extraAggExprs will be changed in the constructed aggregate, and the
@@ -705,30 +658,16 @@ class Analyzer(override val catalogManager: CatalogManager)
// CUBE/ROLLUP/GROUPING SETS. This also replace grouping()/grouping_id() in resolved
// Filter/Sort.
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
case h @ UnresolvedHaving(
_, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(h)
case h @ UnresolvedHaving(
_, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(h)
case h @ UnresolvedHaving(_, g: GroupingSets)
if g.childrenResolved && g.expressions.forall(_.resolved) =>
case h @ UnresolvedHaving(_, agg @ Aggregate(Seq(gs: GroupingSet), aggregateExpressions, _))
if agg.childrenResolved && (gs.groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(h)

case a if !a.childrenResolved => a // be sure all of the children are resolved.

// Ensure group by expressions and aggregate expressions have been resolved.
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child)
if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child)
if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
// Ensure all the expressions have been resolved.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
constructAggregate(x.selectedGroupByExprs, x.groupByExprs, x.aggregations, x.child)
case Aggregate(Seq(gs: GroupingSet), aggregateExpressions, child)
if (gs.groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
constructAggregate(gs.selectedGroupByExprs, gs.groupByExprs, aggregateExpressions, child)

// We should make sure all expressions in condition have been resolved.
case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
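
The Analyzer hunks above replace the separate `Cube`, `Rollup`, and `GroupingSets` cases with a single match on `Aggregate(Seq(gs: GroupingSet), _, _)`. The idea can be sketched with a toy model in which each grouping construct computes its own grouping sets, so callers need only one case. All names below (`Expr`, `Col`, `GroupingSetModel`, `RollupModel`, `CubeModel`, `Agg`, `expandedSets`) are hypothetical stand-ins and far simpler than the Catalyst classes.

```scala
object GroupingSetModelSketch {
  // Toy stand-ins; the real Catalyst expression classes are richer than this.
  sealed trait Expr
  case class Col(name: String) extends Expr

  // Common interface: every grouping construct exposes its own expansion.
  sealed trait GroupingSetModel extends Expr {
    def groupByExprs: Seq[Expr]
    def selectedGroupByExprs: Seq[Seq[Expr]]
  }

  case class RollupModel(groupByExprs: Seq[Expr]) extends GroupingSetModel {
    // Every prefix of the expression list, longest first.
    def selectedGroupByExprs: Seq[Seq[Expr]] = groupByExprs.inits.toIndexedSeq
  }

  case class CubeModel(groupByExprs: Seq[Expr]) extends GroupingSetModel {
    // Every subset, built with the same recursion as the removed cubeExprs0 helper.
    private def subsets(exprs: List[Expr]): Seq[Seq[Expr]] = exprs match {
      case x :: xs =>
        val rest = subsets(xs)
        rest.map(x +: _) ++ rest
      case Nil => Seq(Seq.empty)
    }
    def selectedGroupByExprs: Seq[Seq[Expr]] = subsets(groupByExprs.toList)
  }

  // Toy aggregate node that only carries its grouping expressions.
  case class Agg(groupingExpressions: Seq[Expr])

  // One case covers both constructs; plain GROUP BY lists fall through.
  def expandedSets(agg: Agg): Option[Seq[Seq[Expr]]] = agg.groupingExpressions match {
    case Seq(gs: GroupingSetModel) => Some(gs.selectedGroupByExprs)
    case _ => None
  }
}
```

With this shape, supporting another grouping construct means adding one more subclass rather than another set of match cases in every rule.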
@@ -1652,26 +1591,6 @@ class Analyzer(override val catalogManager: CatalogManager)

a.copy(resolvedGroupingExprs, resolvedAggExprs, a.child)

// SPARK-31670: Resolve Struct field in selectedGroupByExprs/groupByExprs and aggregations
// will be wrapped with alias like Alias(GetStructField, name) with different ExprId.
// This cause aggregateExpressions can't be replaced by expanded groupByExpressions in
// `ResolveGroupingAnalytics.constructAggregateExprs()`, we trim unnecessary alias
// of GetStructField here.
case g: GroupingSets =>
val resolvedSelectedExprs = g.selectedGroupByExprs
.map(_.map(resolveExpressionByPlanChildren(_, g))
.map(trimTopLevelGetStructFieldAlias))

val resolvedGroupingExprs = g.groupByExprs
.map(resolveExpressionByPlanChildren(_, g))
.map(trimTopLevelGetStructFieldAlias)

val resolvedAggExprs = g.aggregations
.map(resolveExpressionByPlanChildren(_, g))
.map(_.asInstanceOf[NamedExpression])

g.copy(resolvedSelectedExprs, resolvedGroupingExprs, g.child, resolvedAggExprs)

case o: OverwriteByExpression if o.table.resolved =>
// The delete condition of `OverwriteByExpression` will be passed to the table
// implementation and should be resolved based on the table schema.
@@ -2068,13 +1987,6 @@ class Analyzer(override val catalogManager: CatalogManager)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
groups.exists(!_.resolved) =>
agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child))

case gs @ GroupingSets(selectedGroups, groups, child, aggs)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
gs.copy(
selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)),
groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child))
}
}

@@ -511,8 +511,6 @@ object FunctionRegistry {
expression[TypeOf]("typeof"),

// grouping sets
expression[Cube]("cube"),
expression[Rollup]("rollup"),
expression[Grouping]("grouping"),
expression[GroupingID]("grouping_id"),

@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Attribute, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID, Now}
import org.apache.spark.sql.catalyst.expressions.{Attribute, CurrentDate, CurrentTimestamp, GroupingSets, MonotonicallyIncreasingID, Now}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -203,15 +203,22 @@ object UnsupportedOperationChecker extends Logging {
// Operations that cannot exists anywhere in a streaming plan
subPlan match {

case Aggregate(_, aggregateExpressions, child) =>
case Aggregate(groupingExpressions, aggregateExpressions, child) =>
val distinctAggExprs = aggregateExpressions.flatMap { expr =>
expr.collect { case ae: AggregateExpression if ae.isDistinct => ae }
}
val haveGroupingSets = groupingExpressions.exists(_.isInstanceOf[GroupingSets])

throwErrorIf(
child.isStreaming && distinctAggExprs.nonEmpty,
"Distinct aggregations are not supported on streaming DataFrames/Datasets. Consider " +
"using approx_count_distinct() instead.")

throwErrorIf(
child.isStreaming && haveGroupingSets,
"Grouping Sets is not supported on streaming DataFrames/Datasets"
)

case _: Command =>
throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " +
"streaming DataFrames/Datasets")
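
The hunk above moves the grouping-sets restriction into the `Aggregate` case: instead of matching a dedicated logical node, the checker inspects the grouping expressions. A rough sketch of that check follows, assuming toy stand-in types. `Expr`, `Col`, `GroupingSetsExpr`, `Agg`, and `check` are hypothetical, and returning `Either` instead of throwing is a simplification.

```scala
object StreamingCheckSketch {
  // Toy stand-ins for Catalyst expressions and the Aggregate node.
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class GroupingSetsExpr(sets: Seq[Seq[Expr]]) extends Expr

  case class Agg(groupingExpressions: Seq[Expr], childIsStreaming: Boolean)

  // Reject grouping sets only when the aggregate reads from a streaming child.
  def check(agg: Agg): Either[String, Unit] = {
    val haveGroupingSets =
      agg.groupingExpressions.exists(_.isInstanceOf[GroupingSetsExpr])
    if (agg.childIsStreaming && haveGroupingSets)
      Left("Grouping Sets is not supported on streaming DataFrames/Datasets")
    else
      Right(())
  }
}
```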
@@ -353,9 +360,6 @@ object UnsupportedOperationChecker extends Logging {
case Intersect(left, right, _) if left.isStreaming && right.isStreaming =>
throwError("Intersect between two streaming DataFrames/Datasets is not supported")

case GroupingSets(_, _, child, _) if child.isStreaming =>
throwError("GroupingSets is not supported on streaming DataFrames/Datasets")

case GlobalLimit(_, _) | LocalLimit(_, _)
if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update =>
throwError("Limits are not supported on streaming DataFrames/Datasets in Update " +