From dc40084beaf48f7cbd3aa206354126143a606be9 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 11 Jan 2023 19:14:30 -0800 Subject: [PATCH 1/2] refactor --- .../catalyst/analysis/ResolveGroupByAll.scala | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala index d45ea412031b1..4b39cb1f646d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala @@ -47,6 +47,24 @@ object ResolveGroupByAll extends Rule[LogicalPlan] { } } + /** + * Returns all the grouping expressions inferred from a GROUP BY ALL aggregate. + * The result is optional. If Spark fails to infer the grouping columns, it is None. + * Otherwise, it contains all the non-aggregate expressions from the project list of the input + * Aggregate. + */ + private def getGroupingExprs(a: Aggregate): Option[Seq[Expression]] = { + val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate)) + // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or + // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will + // not automatically infer the grouping column to be "i". + if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) { + None + } else { + Some(groupingExprs) + } + } + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, AGGREGATE), ruleId) { case a: Aggregate @@ -54,18 +72,15 @@ object ResolveGroupByAll extends Rule[LogicalPlan] { // Only makes sense to do the rewrite once all the aggregate expressions have been resolved. // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping // expression list (because we don't know they would be aggregate expressions until resolved). - val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate)) + val groupingExprs = getGroupingExprs(a) - // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or - // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will - // not automatically infer the grouping column to be "i". - if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) { - // Case (2): don't replace the ALL. We will eventually tell the user in checkAnalysis - // that we cannot resolve the all in group by. + if (groupingExprs.isEmpty) { + // Don't replace the ALL when we fail to infer the grouping columns. We will eventually + // tell the user in checkAnalysis that we cannot resolve the all in group by. a } else { - // Case (1): this is a valid global aggregate. - a.copy(groupingExpressions = groupingExprs) + // This is a valid GROUP BY ALL aggregate. + a.copy(groupingExpressions = groupingExprs.get) } } @@ -94,8 +109,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] { */ def checkAnalysis(operator: LogicalPlan): Unit = operator match { case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && matchToken(a) => - val noAgg = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate)) - if (noAgg.isEmpty && a.aggregateExpressions.exists(containsAttribute)) { + if (getGroupingExprs(a).isEmpty) { operator.failAnalysis( errorClass = "UNRESOLVED_ALL_IN_GROUP_BY", messageParameters = Map.empty) From f2852fa6a2a43b6a49524accdbcc918c84b6b311 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 11 Jan 2023 19:30:04 -0800 Subject: [PATCH 2/2] rename method --- .../spark/sql/catalyst/analysis/ResolveGroupByAll.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala index 4b39cb1f646d8..8c6ba20cd1af9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala @@ -53,7 +53,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] { * Otherwise, it contains all the non-aggregate expressions from the project list of the input * Aggregate. */ - private def getGroupingExprs(a: Aggregate): Option[Seq[Expression]] = { + private def getGroupingExpressions(a: Aggregate): Option[Seq[Expression]] = { val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate)) // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will @@ -72,7 +72,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] { // Only makes sense to do the rewrite once all the aggregate expressions have been resolved. // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping // expression list (because we don't know they would be aggregate expressions until resolved). - val groupingExprs = getGroupingExprs(a) + val groupingExprs = getGroupingExpressions(a) if (groupingExprs.isEmpty) { // Don't replace the ALL when we fail to infer the grouping columns. We will eventually @@ -109,7 +109,7 @@ object ResolveGroupByAll extends Rule[LogicalPlan] { */ def checkAnalysis(operator: LogicalPlan): Unit = operator match { case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && matchToken(a) => - if (getGroupingExprs(a).isEmpty) { + if (getGroupingExpressions(a).isEmpty) { operator.failAnalysis( errorClass = "UNRESOLVED_ALL_IN_GROUP_BY", messageParameters = Map.empty)