Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
Expand Down Expand Up @@ -661,37 +660,36 @@ class Analyzer(override val catalogManager: CatalogManager)
g.aggregations, g.child)
}
// Try resolving the condition of the filter as though it is in the aggregate clause
val resolvedInfo =
ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving)
val (extraAggExprs, Seq(resolvedHavingCond)) =
ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(h.havingCondition), aggForResolving)

// Push the aggregate expressions into the aggregate (if any).
if (resolvedInfo.nonEmpty) {
val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
val newChild = h.child match {
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case x: GroupingSets =>
constructAggregate(
x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
}

// Since the exprId of extraAggExprs will be changed in the constructed aggregate, and the
// aggregateExpressions keeps the input order. So here we build an exprMap to resolve the
// condition again.
val exprMap = extraAggExprs.zip(
newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight(
extraAggExprs.length)).toMap
val newCond = resolvedHavingCond.transform {
case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne)
}
val newChild = h.child match {
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case x: GroupingSets =>
constructAggregate(
x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
}

// The exprIds of extraAggExprs change in the constructed aggregate, but its output keeps the
// input order. So here we build an attrMap from the old attributes to the new output
// attributes and use it to rewrite the resolved condition.
val attrMap = AttributeMap((aggForResolving.output ++ extraAggExprs.map(_.toAttribute))
.zip(newChild.output))
val newCond = resolvedHavingCond.transform {
case a: Attribute => attrMap.getOrElse(a, a)
}

if (extraAggExprs.isEmpty) {
Filter(newCond, newChild)
} else {
Project(newChild.output.dropRight(extraAggExprs.length),
Filter(newCond, newChild))
} else {
h
}
}

Expand Down Expand Up @@ -2491,162 +2489,115 @@ class Analyzer(override val catalogManager: CatalogManager)
// resolve the having condition expression, here we skip resolving it in ResolveReferences
// and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
resolveHaving(Filter(cond, agg), agg)

case f @ Filter(_, agg: Aggregate) if agg.resolved =>
resolveHaving(f, agg)
// HAVING can only use aggregate functions and grouping columns, so we can't resolve the
// references based on `agg.output`.
resolveOperatorWithAggregate(Seq(cond), agg, (resolvedExprs, newChild) => {
Filter(resolvedExprs.head, newChild)
})

case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
case Filter(cond, agg: Aggregate) if agg.resolved =>
// We should resolve the references normally based on child.output first.
val maybeResolved = resolveExpressionByPlanOutput(cond, agg)
resolveOperatorWithAggregate(Seq(maybeResolved), agg, (resolvedExprs, newChild) => {
Filter(resolvedExprs.head, newChild)
})

// Try resolving the ordering as though it is in the aggregate clause.
try {
// If a sort order is unresolved, containing references not in aggregate, or containing
// `AggregateExpression`, we need to push down it to the underlying aggregate operator.
val unresolvedSortOrders = sortOrder.filter { s =>
!s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
case Sort(sortOrder, global, agg: Aggregate) if agg.resolved =>
// We should resolve the references normally based on child.output first.
val maybeResolved = sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg))
resolveOperatorWithAggregate(maybeResolved, agg, (resolvedExprs, newChild) => {
val newSortOrder = sortOrder.zip(resolvedExprs).map {
case (sortOrder, expr) => sortOrder.copy(child = expr)
}
val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())

val aggregateWithExtraOrdering = aggregate.copy(
aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering)

val resolvedAggregate: Aggregate =
executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate]

val (reResolvedAggExprs, resolvedAliasedOrdering) =
resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length)

// If we pass the analysis check, then the ordering expressions should only reference to
// aggregate expressions or grouping expressions, and it's safe to push them down to
// Aggregate.
checkAnalysis(resolvedAggregate)

val originalAggExprs = aggregate.aggregateExpressions.map(trimNonTopLevelAliases)

// If the ordering expression is same with original aggregate expression, we don't need
// to push down this ordering expression and can reference the original aggregate
// expression instead.
val needsPushDown = ArrayBuffer.empty[NamedExpression]
val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering)
val evaluatedOrderings =
resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map {
case (evaluated, (order, aliasOrder)) =>
val index = reResolvedAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
case other => other semanticEquals evaluated.child
}
Sort(newSortOrder, global, newChild)
})
}

if (index == -1) {
if (hasCharVarchar(evaluated)) {
needsPushDown += aliasOrder
order.copy(child = aliasOrder)
} else {
needsPushDown += evaluated
order.copy(child = evaluated.toAttribute)
}
} else {
order.copy(child = originalAggExprs(index).toAttribute)
}
def resolveExprsWithAggregate(
exprs: Seq[Expression],
agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = {
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
val transformed = exprs.map { e =>
// Try resolving the expression as though it is in the aggregate clause.
val maybeResolved = resolveExpressionByPlanOutput(e, agg.child)
if (maybeResolved.resolved && maybeResolved.references.subsetOf(agg.outputSet) &&
!containsAggregate(maybeResolved)) {
// The given expression is valid and doesn't need extra resolution.
maybeResolved
} else if (containsUnresolvedFunc(maybeResolved)) {
// The given expression has unresolved functions which may be aggregate functions and we
// need to wait for other rules to resolve the functions first.
maybeResolved
} else {
// Avoid adding an extra aggregate expression if it's already present in
// `agg.aggregateExpressions`.
val index = if (maybeResolved.resolved) {
agg.aggregateExpressions.indexWhere {
case Alias(child, _) => child semanticEquals maybeResolved
case other => other semanticEquals maybeResolved
}
} else {
-1
}

val sortOrdersMap = unresolvedSortOrders
.map(new TreeNodeRef(_))
.zip(evaluatedOrderings)
.toMap
val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))

// Since we don't rely on sort.resolved as the stop condition for this rule,
// we need to check this and prevent applying this rule multiple times
if (sortOrder == finalSortOrders) {
sort
if (index >= 0) {
agg.aggregateExpressions(index).toAttribute
} else {
Project(aggregate.output,
Sort(finalSortOrders, global,
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
buildAggExprList(maybeResolved, agg, aggregateExpressions)
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => sort
}
}
(aggregateExpressions.toSeq, transformed)
}

def hasCharVarchar(expr: Alias): Boolean = {
expr.find {
case ne: NamedExpression => CharVarcharUtils.getRawType(ne.metadata).nonEmpty
case _ => false
}.nonEmpty
/**
 * Rewrites `expr` so that the parts which must be computed by `agg` are replaced with
 * attributes, and records the corresponding named expressions in `aggExprList`.
 *
 * The cases are tried in order, so the first matching one wins:
 *  - a resolved [[AggregateExpression]] is aliased, appended to `aggExprList` and replaced by
 *    the alias' attribute (its children are not visited);
 *  - a resolved expression that is semantically equal to one of `agg.groupingExpressions`, for
 *    which `ResolveGroupingAnalytics.hasGroupingFunction` is false, and that is not already in
 *    `agg.output`, is appended to `aggExprList` (aliased if it is not a [[NamedExpression]])
 *    and replaced by its attribute;
 *  - an [[Attribute]] contained in `agg.child.outputSet` is turned back into an
 *    [[UnresolvedAttribute]] with the same name;
 *  - any other resolved expression is rebuilt from its recursively rewritten children;
 *  - anything else (i.e. an unresolved expression) is returned unchanged.
 *
 * @param expr the expression to rewrite.
 * @param agg the aggregate whose grouping expressions, output and child output are consulted.
 * @param aggExprList buffer that is mutated in place: every expression that has to be added to
 *                    the aggregate list is appended to it.
 * @return the rewritten expression.
 */
private def buildAggExprList(
expr: Expression,
agg: Aggregate,
aggExprList: ArrayBuffer[NamedExpression]): Expression = expr match {
// A resolved aggregate function: alias it (named after its string form), record the alias and
// refer to it by attribute. No recursion happens into the aggregate function's children.
case ae: AggregateExpression if ae.resolved =>
val alias = Alias(ae, ae.toString)()
aggExprList += alias
alias.toAttribute
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case grouping: Expression if grouping.resolved &&
agg.groupingExpressions.exists(_.semanticEquals(grouping)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(grouping) &&
!agg.output.exists(_.semanticEquals(grouping)) =>
// A grouping expression that `agg` does not output yet: record it so it can be referenced.
grouping match {
case ne: NamedExpression =>
// Already named, so it can be recorded as is.
aggExprList += ne
ne.toAttribute
case _ =>
// Not named: wrap it in an alias first so that it has an attribute to refer to.
val alias = Alias(grouping, grouping.toString)()
aggExprList += alias
alias.toAttribute
}
case a: Attribute if agg.child.outputSet.contains(a) =>
// Undo the resolution. This attribute is neither inside aggregate functions nor a
// grouping column. It shouldn't be resolved with `agg.child.output`.
UnresolvedAttribute(Seq(a.name))
// Any other resolved expression: rewrite its children and rebuild the node.
case other if other.resolved =>
other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList)))
// Unresolved expressions are left untouched.
case _ => expr
}

def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
def containsAggregate(expr: Expression): Boolean = {
expr.find(_.isInstanceOf[AggregateExpression]).isDefined
}

def resolveFilterCondInAggregate(
filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = {
try {
val aggregatedCondition =
Aggregate(
agg.groupingExpressions,
Alias(filterCond, "havingCondition")() :: Nil,
agg.child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head

// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved) {
// Try to replace all aggregate expressions in the filter by an alias.
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
val transformedAggregateFilter = resolvedAggregateFilter.transform {
case ae: AggregateExpression =>
val alias = Alias(ae, ae.toString)()
aggregateExpressions += alias
alias.toAttribute
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
!agg.output.exists(_.semanticEquals(e)) =>
e match {
case ne: NamedExpression =>
aggregateExpressions += ne
ne.toAttribute
case _ =>
val alias = Alias(e, e.toString)()
aggregateExpressions += alias
alias.toAttribute
}
}
if (aggregateExpressions.nonEmpty) {
Some(aggregateExpressions.toSeq, transformedAggregateFilter)
} else {
None
}
} else {
None
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return None and the caller side will return the original plan.
case ae: AnalysisException => None
}
def containsUnresolvedFunc(expr: Expression): Boolean = {
expr.find(_.isInstanceOf[UnresolvedFunction]).isDefined
}

def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
// Try resolving the condition of the filter as though it is in the aggregate clause
val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg)

// Push the aggregate expressions into the aggregate (if any).
if (resolvedInfo.nonEmpty) {
val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
Project(agg.output,
Filter(resolvedHavingCond,
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
def resolveOperatorWithAggregate(
exprs: Seq[Expression],
agg: Aggregate,
buildOperator: (Seq[Expression], Aggregate) => LogicalPlan): LogicalPlan = {
val (extraAggExprs, resolvedExprs) = resolveExprsWithAggregate(exprs, agg)
if (extraAggExprs.isEmpty) {
buildOperator(resolvedExprs, agg)
} else {
filter
Project(agg.output, buildOperator(resolvedExprs, agg.copy(
aggregateExpressions = agg.aggregateExpressions ++ extraAggExprs)))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,9 @@ class AnalysisSuite extends AnalysisTest with Matchers {
val b = testRelation2.output(1)
val c = testRelation2.output(2)
val alias_a3 = count(a).as("a3")
val alias_b = b.as("aggOrder")

// Case 1: when the child of Sort is not Aggregate,
// the sort reference is handled by the rule ResolveSortReferences
// the sort reference is handled by the rule ResolveMissingReferences
val plan1 = testRelation2
.groupBy($"a", $"c", $"b")($"a", $"c", count($"a").as("a3"))
.select($"a", $"c", $"a3")
Expand All @@ -194,8 +193,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
.orderBy($"b".asc)

val expected2 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, alias_b)
.orderBy(alias_b.toAttribute.asc)
.groupBy(a, c, b)(a, c, alias_a3, b)
.orderBy(b.asc)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan2, expected2)
Expand Down Expand Up @@ -415,7 +414,6 @@ class AnalysisSuite extends AnalysisTest with Matchers {
val expected = testRelation2
.groupBy(a, c)(alias1, alias2, alias3)
.orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc)
.select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute)
checkAnalysis(plan, expected)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,9 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
val expected = Project(Seq(a, b), Sort(
Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true,
Seq(SortOrder(grouping_a, Ascending)), true,
Aggregate(Seq(a, b, gid),
Seq(a, b, grouping_a.as("aggOrder")),
Seq(a, b, gid),
Expand(
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L),
Seq(a, b, c, a, b, 0L)),
Expand All @@ -308,9 +308,9 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
val expected3 = Project(Seq(a, b), Sort(
Seq(SortOrder('aggOrder.long.withNullability(false), Ascending)), true,
Seq(SortOrder(gid, Ascending)), true,
Aggregate(Seq(a, b, gid),
Seq(a, b, gid.as("aggOrder")),
Seq(a, b, gid),
Expand(
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L),
Seq(a, b, c, a, b, 0L)),
Expand Down