Skip to content

Commit

Permalink
[SPARK-10165] [SQL] Await child resolution in ResolveFunctions
Browse files Browse the repository at this point in the history
Currently, we eagerly attempt to resolve functions, even before their children are resolved.  However, this is not valid in cases where we need to know the types of the input arguments (i.e. when resolving Hive UDFs).

As a fix, this PR delays function resolution until the functions children are resolved.  This change also necessitates a change to the way we resolve aggregate expressions that are not in aggregate operators (e.g., in `HAVING` or `ORDER BY` clauses).  Specifically, we can't assume that these misplaced functions will be resolved, allowing us to differentiate aggregate functions from normal functions.  To compensate for this change we now attempt to resolve these unresolved expressions in the context of the aggregate operator, before checking to see if any aggregate expressions are present.

Author: Michael Armbrust <michael@databricks.com>

Closes #8371 from marmbrus/hiveUDFResolution.

(cherry picked from commit 2bf338c)
Signed-off-by: Michael Armbrust <michael@databricks.com>
  • Loading branch information
marmbrus committed Aug 25, 2015
1 parent 8ca8bdd commit 228e429
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Analyzer(
ResolveAliases ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
ResolveAggregateFunctions ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
Expand Down Expand Up @@ -452,37 +452,6 @@ class Analyzer(
logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
}
case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
if !s.resolved && a.resolved =>
// A small hack to create an object that will allow us to resolve any references that
// refer to named expressions that are present in the grouping expressions.
val groupingRelation = LocalRelation(
grouping.collect { case ne: NamedExpression => ne.toAttribute }
)

// Find sort attributes that are projected away so we can temporarily add them back in.
val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation)

// Find aggregate expressions and evaluate them early, since they can't be evaluated in a
// Sort.
val (withAggsRemoved, aliasedAggregateList) = newOrdering.map {
case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty =>
val aliased = Alias(aggOrdering.child, "_aggOrdering")()
(aggOrdering.copy(child = aliased.toAttribute), Some(aliased))

case other => (other, None)
}.unzip

val missing = missingAttr ++ aliasedAggregateList.flatten

if (missing.nonEmpty) {
// Add missing grouping exprs and then project them away after the sort.
Project(a.output,
Sort(withAggsRemoved, global,
Aggregate(grouping, aggs ++ missing, child)))
} else {
s // Nothing we can do here. Return original plan.
}
}

/**
Expand Down Expand Up @@ -515,6 +484,7 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
registry.lookupFunction(name, children) match {
Expand Down Expand Up @@ -559,21 +529,79 @@ class Analyzer(
}

/**
* This rule finds expressions in HAVING clause filters that depend on
* unresolved attributes. It pushes these expressions down to the underlying
* aggregates and then projects them away above the filter.
* This rule finds aggregate expressions that are not in an aggregate operator. For example,
* those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the
* underlying aggregate operator and then projected away after the original operator.
*/
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
if aggregate.resolved && containsAggregate(havingCondition) =>

val evaluatedCondition = Alias(havingCondition, "havingCondition")()
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
case filter @ Filter(havingCondition,
aggregate @ Aggregate(grouping, originalAggExprs, child))
if aggregate.resolved && !filter.resolved =>

// Try resolving the condition of the filter as though it is in the aggregate clause
val aggregatedCondition =
Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child)
val resolvedOperator = execute(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head

// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs

Project(aggregate.output,
Filter(resolvedAggregateFilter.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
} else {
filter
}

Project(aggregate.output,
Filter(evaluatedCondition.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
case sort @ Sort(sortOrder, global,
aggregate @ Aggregate(grouping, originalAggExprs, child))
if aggregate.resolved && !sort.resolved =>

// Try resolving the ordering as though it is in the aggregate clause.
try {
val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions

// Expressions that have an aggregate can be pushed down.
val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)

// Attribute references, that are missing from the order but are present in the grouping
// expressions can also be pushed down.
val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
val missingAttributes = requiredAttributes -- aggregate.outputSet
val validPushdownAttributes =
missingAttributes.filter(a => grouping.exists(a.semanticEquals))

// If resolution was successful and we see the ordering either has an aggregate in it or
// it is missing something that is projected away by the aggregate, add the ordering
// the original aggregate operator.
if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
case (order, evaluated) => order.copy(child = evaluated.toAttribute)
}
val aggExprsWithOrdering: Seq[NamedExpression] =
resolvedAggregateOrdering ++ originalAggExprs

Project(aggregate.output,
Sort(evaluatedOrderings, global,
aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
} else {
sort
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => sort
}
}

protected def containsAggregate(condition: Expression): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,11 @@ class HiveUDFSuite extends QueryTest {
checkAnswer(
sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"),
Seq(Row("hello world"), Row("hello goodbye")))

checkAnswer(
sql("SELECT testStringStringUDF(\"\", testStringStringUDF(\"hello\", s)) FROM stringTable"),
Seq(Row(" hello world"), Row(" hello goodbye")))

sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF")

TestHive.reset()
Expand Down

0 comments on commit 228e429

Please sign in to comment.