Skip to content

Commit

Permalink
Add more tests and deal with aggregate.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jul 12, 2018
1 parent 6eda8d2 commit 8432b00
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
Expand Up @@ -1176,7 +1176,7 @@ class Analyzer(
case a @ Aggregate(groupExprs, aggExprs, child) =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
// All the missing attributes are grouping expressions, valid case.
(newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
Expand Down Expand Up @@ -1496,7 +1496,11 @@ class Analyzer(

// Try resolving the ordering as though it is in the aggregate clause.
try {
val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
// If a sort order is unresolved, containing references not in aggregate, or containing
// `AggregateExpression`, we need to push down it to the underlying aggregate operator.
val unresolvedSortOrders = sortOrder.filter { s =>
!s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
}
val aliasedOrdering =
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Expand Up @@ -2397,5 +2397,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val sort1 = df.select(df("name")).orderBy(df("id"))
val sort2 = df.select(col("name")).orderBy(col("id"))
checkAnswer(sort1, sort2.collect())

withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
checkAnswer(aggPlusSort1, aggPlusSort2.collect())

val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0)
val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0)
checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
}
}
}

0 comments on commit 8432b00

Please sign in to comment.