diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ee02d501d10f0..591747b45c376 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -567,22 +567,40 @@ class Analyzer( try { val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] - val resolvedOrdering = resolvedOperator.aggregateExpressions + val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAliasedOrdering: Seq[Alias] = + resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] // If we pass the analysis check, then the ordering expressions should only reference to // aggregate expressions or grouping expressions, and it's safe to push them down to // Aggregate. - checkAnalysis(resolvedOperator) - // todo: some ordering expressions can be evaluated with existing aggregate expressions - // and we don't need to push them down to Aggregate. - val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedOrdering).map { - case (order, evaluated) => order.copy(child = evaluated.toAttribute) + checkAnalysis(resolvedAggregate) + + val originalAggExprs = aggregate.aggregateExpressions.map( + CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + + // If the ordering expression is same with original aggregate expression, we don't need + // to push down this ordering expression and can reference the original aggregate + // expression instead. + val needsPushDown = ArrayBuffer.empty[NamedExpression] + val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + case (evaluated, order) => + val index = originalAggExprs.indexWhere { + case Alias(child, _) => child semanticEquals evaluated.child + case other => other semanticEquals evaluated.child + } + + if (index == -1) { + needsPushDown += evaluated + order.copy(child = evaluated.toAttribute) + } else { + order.copy(child = originalAggExprs(index).toAttribute) + } } - val aggExprsWithOrdering = aggregate.aggregateExpressions ++ resolvedOrdering + Project(aggregate.output, Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan.