diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
index d45ea412031b1..8c6ba20cd1af9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala
@@ -47,6 +47,27 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Infers the grouping expressions of a GROUP BY ALL aggregate: every aggregate expression
+   * of the input Aggregate that does not contain an aggregate function is treated as a
+   * grouping expression.
+   *
+   * @param a the aggregate whose grouping expressions should be inferred
+   * @return `None` if the grouping columns cannot be inferred; otherwise the non-aggregate
+   *         expressions of `a.aggregateExpressions`, which is empty for a global aggregate
+   */
+  private def getGroupingExpressions(a: Aggregate): Option[Seq[Expression]] = {
+    val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
+    // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or
+    // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
+    // not automatically infer the grouping column to be "i".
+    if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+      None
+    } else {
+      Some(groupingExprs)
+    }
+  }
+
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
     _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, AGGREGATE), ruleId) {
     case a: Aggregate
@@ -54,18 +75,14 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
       // Only makes sense to do the rewrite once all the aggregate expressions have been resolved.
       // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping
       // expression list (because we don't know they would be aggregate expressions until resolved).
-      val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
-
-      // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or
-      // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will
-      // not automatically infer the grouping column to be "i".
-      if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
-        // Case (2): don't replace the ALL. We will eventually tell the user in checkAnalysis
-        // that we cannot resolve the all in group by.
-        a
-      } else {
-        // Case (1): this is a valid global aggregate.
-        a.copy(groupingExpressions = groupingExprs)
+      getGroupingExpressions(a) match {
+        case Some(groupingExprs) =>
+          // This is a valid GROUP BY ALL aggregate.
+          a.copy(groupingExpressions = groupingExprs)
+        case None =>
+          // Don't replace the ALL when we fail to infer the grouping columns. We will eventually
+          // tell the user in checkAnalysis that we cannot resolve the all in group by.
+          a
       }
   }
 
@@ -94,8 +111,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] {
    */
   def checkAnalysis(operator: LogicalPlan): Unit = operator match {
     case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && matchToken(a) =>
-      val noAgg = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate))
-      if (noAgg.isEmpty && a.aggregateExpressions.exists(containsAttribute)) {
+      if (getGroupingExpressions(a).isEmpty) {
         operator.failAnalysis(
           errorClass = "UNRESOLVED_ALL_IN_GROUP_BY",
           messageParameters = Map.empty)