Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10389][SQL] support order by non-attribute grouping expression on Aggregate #8548

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -560,43 +560,47 @@ class Analyzer(
filter
}

case sort @ Sort(sortOrder, global,
aggregate @ Aggregate(grouping, originalAggExprs, child))
case sort @ Sort(sortOrder, global, aggregate: Aggregate)
if aggregate.resolved && !sort.resolved =>

// Try resolving the ordering as though it is in the aggregate clause.
try {
val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions

// Expressions that have an aggregate can be pushed down.
val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)

// Attribute references, that are missing from the order but are present in the grouping
// expressions can also be pushed down.
val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
val missingAttributes = requiredAttributes -- aggregate.outputSet
val validPushdownAttributes =
missingAttributes.filter(a => grouping.exists(a.semanticEquals))

// If resolution was successful and we see the ordering either has an aggregate in it or
// it is missing something that is projected away by the aggregate, add the ordering
// to the original aggregate operator.
if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
case (order, evaluated) => order.copy(child = evaluated.toAttribute)
}
val aggExprsWithOrdering: Seq[NamedExpression] =
resolvedAggregateOrdering ++ originalAggExprs

Project(aggregate.output,
Sort(evaluatedOrderings, global,
aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
} else {
sort
val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]

// If we pass the analysis check, then the ordering expressions should only refer to
// aggregate expressions or grouping expressions, and it's safe to push them down to
// Aggregate.
checkAnalysis(resolvedAggregate)

val originalAggExprs = aggregate.aggregateExpressions.map(
CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])

// If the ordering expression is the same as an original aggregate expression, we don't need
// to push down this ordering expression and can reference the original aggregate
// expression instead.
val needsPushDown = ArrayBuffer.empty[NamedExpression]
val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
case (evaluated, order) =>
val index = originalAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
case other => other semanticEquals evaluated.child
}

if (index == -1) {
needsPushDown += evaluated
order.copy(child = evaluated.toAttribute)
} else {
order.copy(child = originalAggExprs(index).toAttribute)
}
}

Project(aggregate.output,
Sort(evaluatedOrderings, global,
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
Expand All @@ -605,9 +609,7 @@ class Analyzer(
}

protected def containsAggregate(condition: Expression): Boolean = {
condition
.collect { case ae: AggregateExpression => ae }
.nonEmpty
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}

Expand Down
19 changes: 15 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1712,9 +1712,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}

test("SPARK-10130 type coercion for IF should have children resolved first") {
val df = Seq((1, 1), (-1, 1)).toDF("key", "value")
df.registerTempTable("src")
checkAnswer(
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
withTempTable("src") {
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
checkAnswer(
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
}
}

test("SPARK-10389: order by non-attribute grouping expression on Aggregate") {
  withTempTable("src") {
    Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")

    // Both rows fall into distinct `key + 1` groups and each group's MAX(value) is 1,
    // so every query below must return the same two rows regardless of the sort order.
    val expectedRows = Seq(Row(1), Row(1))

    // First query orders by the grouping expression itself; the second orders by an
    // expression built on top of the grouping expression.
    val queries = Seq(
      "SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1",
      "SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2")

    queries.foreach { query =>
      checkAnswer(sql(query), expectedRows)
    }
  }
}
}