From 297d06728bc41db5aae9aa7d8ae9aa05f9f6b0c0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Mar 2021 12:52:59 +0800 Subject: [PATCH] Simplify ResolveAggregateFunctions --- .../sql/catalyst/analysis/Analyzer.scala | 285 ++++++++---------- .../sql/catalyst/analysis/AnalysisSuite.scala | 8 +- .../ResolveGroupingAnalyticsSuite.scala | 8 +- 3 files changed, 125 insertions(+), 176 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 68a3f3d5aa4b3..8208a6285f5fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -38,7 +38,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 -import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -661,37 +660,36 @@ class Analyzer(override val catalogManager: CatalogManager) g.aggregations, g.child) } // Try resolving the condition of the filter as though it is in the aggregate clause - val resolvedInfo = - ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving) + val (extraAggExprs, Seq(resolvedHavingCond)) = + ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(h.havingCondition), aggForResolving) // Push the aggregate expressions into the aggregate (if any). - if (resolvedInfo.nonEmpty) { - val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get - val newChild = h.child match { - case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => - constructAggregate( - cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child) - case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => - constructAggregate( - rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child) - case x: GroupingSets => - constructAggregate( - x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child) - } - - // Since the exprId of extraAggExprs will be changed in the constructed aggregate, and the - // aggregateExpressions keeps the input order. So here we build an exprMap to resolve the - // condition again. - val exprMap = extraAggExprs.zip( - newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight( - extraAggExprs.length)).toMap - val newCond = resolvedHavingCond.transform { - case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne) - } + val newChild = h.child match { + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => + constructAggregate( + cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child) + case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => + constructAggregate( + rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child) + case x: GroupingSets => + constructAggregate( + x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child) + } + + // Since the exprId of extraAggExprs will be changed in the constructed aggregate, and the + // aggregateExpressions keeps the input order. So here we build an exprMap to resolve the + // condition again. + val attrMap = AttributeMap((aggForResolving.output ++ extraAggExprs.map(_.toAttribute)) + .zip(newChild.output)) + val newCond = resolvedHavingCond.transform { + case a: Attribute => attrMap.getOrElse(a, a) + } + + if (extraAggExprs.isEmpty) { + Filter(newCond, newChild) + } else { Project(newChild.output.dropRight(extraAggExprs.length), Filter(newCond, newChild)) - } else { - h } } @@ -2491,162 +2489,115 @@ class Analyzer(override val catalogManager: CatalogManager) // resolve the having condition expression, here we skip resolving it in ResolveReferences // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519. case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved => - resolveHaving(Filter(cond, agg), agg) - - case f @ Filter(_, agg: Aggregate) if agg.resolved => - resolveHaving(f, agg) + // HAVING can only use aggregate functions and grouping columns, so we can't resolve the + // references based on `agg.output`. + resolveOperatorWithAggregate(Seq(cond), agg, (resolvedExprs, newChild) => { + Filter(resolvedExprs.head, newChild) + }) - case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => + case Filter(cond, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. + val maybeResolved = resolveExpressionByPlanOutput(cond, agg) + resolveOperatorWithAggregate(Seq(maybeResolved), agg, (resolvedExprs, newChild) => { + Filter(resolvedExprs.head, newChild) + }) - // Try resolving the ordering as though it is in the aggregate clause. - try { - // If a sort order is unresolved, containing references not in aggregate, or containing - // `AggregateExpression`, we need to push down it to the underlying aggregate operator. - val unresolvedSortOrders = sortOrder.filter { s => - !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + case Sort(sortOrder, global, agg: Aggregate) if agg.resolved => + // We should resolve the references normally based on child.output first. + val maybeResolved = sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg)) + resolveOperatorWithAggregate(maybeResolved, agg, (resolvedExprs, newChild) => { + val newSortOrder = sortOrder.zip(resolvedExprs).map { + case (sortOrder, expr) => sortOrder.copy(child = expr) } - val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) - - val aggregateWithExtraOrdering = aggregate.copy( - aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) - - val resolvedAggregate: Aggregate = - executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] - - val (reResolvedAggExprs, resolvedAliasedOrdering) = - resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length) - - // If we pass the analysis check, then the ordering expressions should only reference to - // aggregate expressions or grouping expressions, and it's safe to push them down to - // Aggregate. - checkAnalysis(resolvedAggregate) - - val originalAggExprs = aggregate.aggregateExpressions.map(trimNonTopLevelAliases) - - // If the ordering expression is same with original aggregate expression, we don't need - // to push down this ordering expression and can reference the original aggregate - // expression instead. - val needsPushDown = ArrayBuffer.empty[NamedExpression] - val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering) - val evaluatedOrderings = - resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map { - case (evaluated, (order, aliasOrder)) => - val index = reResolvedAggExprs.indexWhere { - case Alias(child, _) => child semanticEquals evaluated.child - case other => other semanticEquals evaluated.child - } + Sort(newSortOrder, global, newChild) + }) + } - if (index == -1) { - if (hasCharVarchar(evaluated)) { - needsPushDown += aliasOrder - order.copy(child = aliasOrder) - } else { - needsPushDown += evaluated - order.copy(child = evaluated.toAttribute) - } - } else { - order.copy(child = originalAggExprs(index).toAttribute) - } + def resolveExprsWithAggregate( + exprs: Seq[Expression], + agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = { + val aggregateExpressions = ArrayBuffer.empty[NamedExpression] + val transformed = exprs.map { e => + // Try resolving the expression as though it is in the aggregate clause. + val maybeResolved = resolveExpressionByPlanOutput(e, agg.child) + if (maybeResolved.resolved && maybeResolved.references.subsetOf(agg.outputSet) && + !containsAggregate(maybeResolved)) { + // The given expression is valid and doesn't need extra resolution. + maybeResolved + } else if (containsUnresolvedFunc(maybeResolved)) { + // The given expression has unresolved functions which may be aggregate functions and we + // need to wait for other rules to resolve the functions first. + maybeResolved + } else { + // Avoid adding an extra aggregate expression if it's already present in + // `agg.aggregateExpressions`. + val index = if (maybeResolved.resolved) { + agg.aggregateExpressions.indexWhere { + case Alias(child, _) => child semanticEquals maybeResolved + case other => other semanticEquals maybeResolved + } + } else { + -1 } - - val sortOrdersMap = unresolvedSortOrders - .map(new TreeNodeRef(_)) - .zip(evaluatedOrderings) - .toMap - val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) - - // Since we don't rely on sort.resolved as the stop condition for this rule, - // we need to check this and prevent applying this rule multiple times - if (sortOrder == finalSortOrders) { - sort + if (index >= 0) { + agg.aggregateExpressions(index).toAttribute } else { - Project(aggregate.output, - Sort(finalSortOrders, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + buildAggExprList(maybeResolved, agg, aggregateExpressions) } - } catch { - // Attempting to resolve in the aggregate can result in ambiguity. When this happens, - // just return the original plan. - case ae: AnalysisException => sort } + } + (aggregateExpressions.toSeq, transformed) } - def hasCharVarchar(expr: Alias): Boolean = { - expr.find { - case ne: NamedExpression => CharVarcharUtils.getRawType(ne.metadata).nonEmpty - case _ => false - }.nonEmpty + private def buildAggExprList( + expr: Expression, + agg: Aggregate, + aggExprList: ArrayBuffer[NamedExpression]): Expression = expr match { + case ae: AggregateExpression if ae.resolved => + val alias = Alias(ae, ae.toString)() + aggExprList += alias + alias.toAttribute + // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. + case grouping: Expression if grouping.resolved && + agg.groupingExpressions.exists(_.semanticEquals(grouping)) && + !ResolveGroupingAnalytics.hasGroupingFunction(grouping) && + !agg.output.exists(_.semanticEquals(grouping)) => + grouping match { + case ne: NamedExpression => + aggExprList += ne + ne.toAttribute + case _ => + val alias = Alias(grouping, grouping.toString)() + aggExprList += alias + alias.toAttribute + } + case a: Attribute if agg.child.outputSet.contains(a) => + // Undo the resolution. This attribute is neither inside aggregate functions nor a + // grouping column. It shouldn't be resolved with `agg.child.output`. + UnresolvedAttribute(Seq(a.name)) + case other if other.resolved => + other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList))) + case _ => expr } - def containsAggregate(condition: Expression): Boolean = { - condition.find(_.isInstanceOf[AggregateExpression]).isDefined + def containsAggregate(expr: Expression): Boolean = { + expr.find(_.isInstanceOf[AggregateExpression]).isDefined } - def resolveFilterCondInAggregate( - filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = { - try { - val aggregatedCondition = - Aggregate( - agg.groupingExpressions, - Alias(filterCond, "havingCondition")() :: Nil, - agg.child) - val resolvedOperator = executeSameContext(aggregatedCondition) - def resolvedAggregateFilter = - resolvedOperator - .asInstanceOf[Aggregate] - .aggregateExpressions.head - - // If resolution was successful and we see the filter has an aggregate in it, add it to - // the original aggregate operator. - if (resolvedOperator.resolved) { - // Try to replace all aggregate expressions in the filter by an alias. - val aggregateExpressions = ArrayBuffer.empty[NamedExpression] - val transformedAggregateFilter = resolvedAggregateFilter.transform { - case ae: AggregateExpression => - val alias = Alias(ae, ae.toString)() - aggregateExpressions += alias - alias.toAttribute - // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. - case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) && - !ResolveGroupingAnalytics.hasGroupingFunction(e) && - !agg.output.exists(_.semanticEquals(e)) => - e match { - case ne: NamedExpression => - aggregateExpressions += ne - ne.toAttribute - case _ => - val alias = Alias(e, e.toString)() - aggregateExpressions += alias - alias.toAttribute - } - } - if (aggregateExpressions.nonEmpty) { - Some(aggregateExpressions.toSeq, transformedAggregateFilter) - } else { - None - } - } else { - None - } - } catch { - // Attempting to resolve in the aggregate can result in ambiguity. When this happens, - // just return None and the caller side will return the original plan. - case ae: AnalysisException => None - } + def containsUnresolvedFunc(expr: Expression): Boolean = { + expr.find(_.isInstanceOf[UnresolvedFunction]).isDefined } - def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = { - // Try resolving the condition of the filter as though it is in the aggregate clause - val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg) - - // Push the aggregate expressions into the aggregate (if any). - if (resolvedInfo.nonEmpty) { - val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get - Project(agg.output, - Filter(resolvedHavingCond, - agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions))) + def resolveOperatorWithAggregate( + exprs: Seq[Expression], + agg: Aggregate, + buildOperator: (Seq[Expression], Aggregate) => LogicalPlan): LogicalPlan = { + val (extraAggExprs, resolvedExprs) = resolveExprsWithAggregate(exprs, agg) + if (extraAggExprs.isEmpty) { + buildOperator(resolvedExprs, agg) } else { - filter + Project(agg.output, buildOperator(resolvedExprs, agg.copy( + aggregateExpressions = agg.aggregateExpressions ++ extraAggExprs))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 38f92c68af020..1433786a13857 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -170,10 +170,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { val b = testRelation2.output(1) val c = testRelation2.output(2) val alias_a3 = count(a).as("a3") - val alias_b = b.as("aggOrder") // Case 1: when the child of Sort is not Aggregate, - // the sort reference is handled by the rule ResolveSortReferences + // the sort reference is handled by the rule ResolveMissingReferences val plan1 = testRelation2 .groupBy($"a", $"c", $"b")($"a", $"c", count($"a").as("a3")) .select($"a", $"c", $"a3") @@ -194,8 +193,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { .orderBy($"b".asc) val expected2 = testRelation2 - .groupBy(a, c, b)(a, c, alias_a3, alias_b) - .orderBy(alias_b.toAttribute.asc) + .groupBy(a, c, b)(a, c, alias_a3, b) + .orderBy(b.asc) .select(a, c, alias_a3.toAttribute) checkAnalysis(plan2, expected2) @@ -415,7 +414,6 @@ class AnalysisSuite extends AnalysisTest with Matchers { val expected = testRelation2 .groupBy(a, c)(alias1, alias2, alias3) .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc) - .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute) checkAnalysis(plan, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index cdfae14138290..4fbfa1bfeaa78 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -287,9 +287,9 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) val expected = Project(Seq(a, b), Sort( - Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true, + Seq(SortOrder(grouping_a, Ascending)), true, Aggregate(Seq(a, b, gid), - Seq(a, b, grouping_a.as("aggOrder")), + Seq(a, b, gid), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), @@ -308,9 +308,9 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) val expected3 = Project(Seq(a, b), Sort( - Seq(SortOrder('aggOrder.long.withNullability(false), Ascending)), true, + Seq(SortOrder(gid, Ascending)), true, Aggregate(Seq(a, b, gid), - Seq(a, b, gid.as("aggOrder")), + Seq(a, b, gid), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)),