[SPARK-42003][SQL] Reduce duplicate code in ResolveGroupByAll #39523

@@ -47,25 +47,40 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Returns all the grouping expressions inferred from a GROUP BY ALL aggregate.
+   * The result is optional. If Spark fails to infer the grouping columns, it is None.
+   * Otherwise, it contains all the non-aggregate expressions from the project list of the input
+   * Aggregate.
+   */
+  private def getGroupingExpressions(a: Aggregate): Option[Seq[Expression]] = {
+    val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+    // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or
+    // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
+    // not automatically infer the grouping column to be "i".
+    if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+      None
+    } else {
+      Some(groupingExprs)
+    }
+  }
+
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
     _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, AGGREGATE), ruleId) {
     case a: Aggregate
       if a.child.resolved && a.aggregateExpressions.forall(_.resolved) && matchToken(a) =>
       // Only makes sense to do the rewrite once all the aggregate expressions have been resolved.
       // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping
       // expression list (because we don't know they would be aggregate expressions until resolved).
-      val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+      val groupingExprs = getGroupingExpressions(a)
 
-      // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or
-      // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
-      // not automatically infer the grouping column to be "i".
-      if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
-        // Case (2): don't replace the ALL. We will eventually tell the user in checkAnalysis
-        // that we cannot resolve the all in group by.
+      if (groupingExprs.isEmpty) {
+        // Don't replace the ALL when we fail to infer the grouping columns. We will eventually
+        // tell the user in checkAnalysis that we cannot resolve the all in group by.
         a
       } else {
-        // Case (1): this is a valid global aggregate.
-        a.copy(groupingExpressions = groupingExprs)
+        // This is a valid GROUP BY ALL aggregate.
+        a.copy(groupingExpressions = groupingExprs.get)
       }
   }
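To make the inference above concrete, here is a standalone sketch of the same rule outside Catalyst, suitable for pasting into a Scala REPL. The Expr/Attr/Agg/Add types and the containsAgg, containsAttr, and inferGrouping helpers are invented purely for illustration; they mirror how getGroupingExpressions filters the select list and how containsAttribute ignores attributes that appear inside aggregate functions.

// Illustration only: a toy expression model, not Catalyst classes.
sealed trait Expr { def children: Seq[Expr] = Nil }
case class Attr(name: String) extends Expr
case class Agg(fn: String, child: Expr) extends Expr { override def children: Seq[Expr] = Seq(child) }
case class Add(left: Expr, right: Expr) extends Expr { override def children: Seq[Expr] = Seq(left, right) }

// True if the expression tree contains an aggregate function anywhere.
def containsAgg(e: Expr): Boolean = e match {
  case _: Agg => true
  case other  => other.children.exists(containsAgg)
}

// True if an attribute is referenced outside an aggregate function
// (mirrors containsAttribute: do not recurse into aggregates).
def containsAttr(e: Expr): Boolean = e match {
  case _: Agg  => false
  case _: Attr => true
  case other   => other.children.exists(containsAttr)
}

// Mirrors getGroupingExpressions: None means GROUP BY ALL cannot be resolved.
def inferGrouping(selectList: Seq[Expr]): Option[Seq[Expr]] =
  selectList.filter(e => !containsAgg(e)) match {
    case Nil if selectList.exists(containsAttr) => None
    case grouping                               => Some(grouping)
  }

inferGrouping(Seq(Attr("i"), Agg("sum", Attr("j"))))      // Some(List(Attr(i))) -> group by i
inferGrouping(Seq(Add(Attr("i"), Agg("sum", Attr("j"))))) // None                -> rejected later
inferGrouping(Seq(Agg("sum", Attr("j"))))                 // Some(List())        -> global aggregate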

@@ -94,8 +109,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
    */
   def checkAnalysis(operator: LogicalPlan): Unit = operator match {
     case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && matchToken(a) =>
-      val noAgg = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
-      if (noAgg.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+      if (getGroupingExpressions(a).isEmpty) {
         operator.failAnalysis(
           errorClass = "UNRESOLVED_ALL_IN_GROUP_BY",
           messageParameters = Map.empty)
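For context on the user-visible behavior that this rule and checkAnalysis enforce, here is a hedged spark-shell style sketch. It assumes a SparkSession named spark on a version that supports GROUP BY ALL, and it creates a hypothetical temp view t with columns i and j just for illustration.

// Sketch only: `spark` and the temp view `t` are assumptions, not part of this PR.
import spark.implicits._

Seq((1, 10), (1, 20), (2, 30)).toDF("i", "j").createOrReplaceTempView("t")

// "i" contains no aggregate, so it is inferred as the grouping expression.
spark.sql("SELECT i, sum(j) FROM t GROUP BY ALL").show()

// A valid global aggregate: the inferred grouping list is empty and no attribute
// appears outside an aggregate function.
spark.sql("SELECT sum(j) FROM t GROUP BY ALL").show()

// The only select item contains an aggregate, so nothing can be inferred as a grouping
// expression, yet it still references "i" outside the aggregate; checkAnalysis rejects
// the query with UNRESOLVED_ALL_IN_GROUP_BY.
// spark.sql("SELECT i + sum(j) FROM t GROUP BY ALL").show()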