diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9dd66492786dd..1ebbfb9a39a61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -38,10 +38,10 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 -import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} +import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -288,7 +288,6 @@ class Analyzer(override val catalogManager: CatalogManager) AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: - WrapLateralColumnAliasReference :: ResolveLateralColumnAliasReference :: ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: @@ -301,8 +300,6 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveGroupByAll :: ResolveOrdinalInOrderByAndGroupBy :: ResolveAggAliasInGroupBy :: - ResolveMissingReferences :: - ResolveOuterReferences :: ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: @@ -693,9 +690,26 @@ class Analyzer(override val catalogManager: CatalogManager) // For CUBE/ROLLUP expressions, to avoid resolving repeatedly, here we 
delete them from // groupingExpressions for condition resolving. val aggForResolving = aggregate.copy(groupingExpressions = groupByExprs) + // HACK ALTER! Ideally we should only resolve GROUPING SETS + HAVING when the having condition + // is fully resolved, similar to the rule `ResolveAggregateFunctions`. However, Aggregate + // with GROUPING SETS is marked as unresolved and many analyzer rules can't apply to + // UnresolvedHaving because its child is not resolved. Here we explicitly resolve columns + // and subqueries of UnresolvedHaving so that the rewrite works in most cases. + // TODO: mark Aggregate as resolved even if it has GROUPING SETS. We can expand it at the end + // of the analysis phase. + val colResolved = h.mapExpressions { e => + resolveExpressionByPlanOutput( + resolveColWithAgg(e, aggForResolving), aggForResolving, allowOuter = true) + } + val cond = if (SubqueryExpression.hasSubquery(colResolved.havingCondition)) { + val fake = Project(Alias(colResolved.havingCondition, "fake")() :: Nil, aggregate.child) + ResolveSubquery(fake).asInstanceOf[Project].projectList.head.asInstanceOf[Alias].child + } else { + colResolved.havingCondition + } // Try resolving the condition of the filter as though it is in the aggregate clause val (extraAggExprs, Seq(resolvedHavingCond)) = - ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(h.havingCondition), aggForResolving) + ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(cond), aggForResolving) // Push the aggregate expressions into the aggregate (if any). val newChild = constructAggregate(selectedGroupByExprs, groupByExprs, @@ -727,7 +741,10 @@ class Analyzer(override val catalogManager: CatalogManager) if agg.childrenResolved && aggExprs.forall(_.resolved) => tryResolveHavingCondition(h, agg, selectedGroupByExprs, groupByExprs) - case a if !a.childrenResolved => a // be sure all of the children are resolved. + // Make sure all of the children are resolved. 
+ // We can't put this at the beginning, because `Aggregate` with GROUPING SETS is unresolved + // but we need to resolve `UnresolvedHaving` above it. + case a if !a.childrenResolved => a // Ensure group by expressions and aggregate expressions have been resolved. case Aggregate(GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, child) @@ -737,7 +754,7 @@ class Analyzer(override val catalogManager: CatalogManager) // We should make sure all expressions in condition have been resolved. case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved => val groupingExprs = findGroupingExprs(child) - // The unresolved grouping id will be resolved by ResolveMissingReferences + // The unresolved grouping id will be resolved by ResolveReferences val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) f.copy(condition = newCond) @@ -746,7 +763,7 @@ class Analyzer(override val catalogManager: CatalogManager) if order.exists(hasGroupingFunction) && order.forall(_.resolved) => val groupingExprs = findGroupingExprs(child) val gid = VirtualColumn.groupingIdAttribute - // The unresolved grouping id will be resolved by ResolveMissingReferences + // The unresolved grouping id will be resolved by ResolveReferences val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) s.copy(order = newOrder) } @@ -1398,8 +1415,21 @@ class Analyzer(override val catalogManager: CatalogManager) } /** - * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from - * a logical plan node's children. + * Resolves [[UnresolvedAttribute]]s with the following precedence: + * 1. Resolves it to [[AttributeReference]] with the output of the children plans. This includes + * metadata columns as well. + * 2. If the plan is Project/Aggregate, resolves it to lateral column alias, which is the alias + * defined previously in the SELECT list. + * 3. 
If the plan is UnresolvedHaving/Filter/Sort + Aggregate, resolves it to + * [[TempResolvedColumn]] with the output of Aggregate's child plan. This is to allow + * UnresolvedHaving/Filter/Sort to host grouping expressions and aggregate functions, which + * can be pushed down to the Aggregate later. + * 4. If the plan is Sort/Filter/RepartitionByExpression, resolves it to [[AttributeReference]] + * with the output of a descendant plan node. Spark will propagate the missing attributes from + * the descendant plan node to the Sort/Filter/RepartitionByExpression node. This is to allow + * users to filter/order/repartition by columns that are not in the SELECT clause, which is + * widely supported in other SQL dialects. + * 5. Resolves it to [[OuterReference]] with the outer plan if this is a subquery plan. */ object ResolveReferences extends Rule[LogicalPlan] { @@ -1420,8 +1450,8 @@ class Analyzer(override val catalogManager: CatalogManager) } } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - AlwaysProcess.fn, ruleId) { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + // Wait for other rules to resolve child plans first case p: LogicalPlan if !p.childrenResolved => p // Wait for the rule `DeduplicateRelations` to resolve conflicting attrs first. @@ -1477,19 +1507,12 @@ class Analyzer(override val catalogManager: CatalogManager) } u.withNewChildren(newChildren) - // When resolve `SortOrder`s in Sort based on child, don't report errors as - // we still have chance to resolve it based on its descendants - case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => - val newOrdering = - ordering.map(order => resolveExpressionByPlanOutput(order, child).asInstanceOf[SortOrder]) - Sort(newOrdering, global, child) - // A special case for Generate, because the output of Generate should not be resolved by // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. 
case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g case g @ Generate(generator, join, outer, qualifier, output, child) => - val newG = resolveExpressionByPlanOutput(generator, child, throws = true) + val newG = resolveExpressionByPlanOutput(generator, child, throws = true, allowOuter = true) if (newG.fastEquals(generator)) { g } else { @@ -1515,14 +1538,25 @@ class Analyzer(override val catalogManager: CatalogManager) } val resolvedGroupingExprs = a.groupingExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve)) + .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = true)) .map(trimTopLevelGetStructFieldAlias) - val resolvedAggExprs = a.aggregateExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve)) - .map(_.asInstanceOf[NamedExpression]) - - a.copy(resolvedGroupingExprs, resolvedAggExprs, a.child) + val resolvedAggExprsNoOuter = a.aggregateExpressions + .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) + // Aggregate supports Lateral column alias, which has higher priority than outer reference. + val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) + val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef) + .map(_.asInstanceOf[NamedExpression]) + a.copy(resolvedGroupingExprs, resolvedAggExprsWithOuter, a.child) + + // Special case for Project as it supports lateral column alias. + case p: Project => + val resolvedNoOuter = p.projectList + .map(resolveExpressionByPlanChildren(_, p, allowOuter = false)) + // Lateral column alias has higher priority than outer reference. 
+ val resolvedWithLCA = resolveLateralColumnAlias(resolvedNoOuter) + val resolvedWithOuter = resolvedWithLCA.map(resolveOuterRef) + p.copy(projectList = resolvedWithOuter.map(_.asInstanceOf[NamedExpression])) case o: OverwriteByExpression if o.table.resolved => // The delete condition of `OverwriteByExpression` will be passed to the table @@ -1606,12 +1640,125 @@ class Analyzer(override val catalogManager: CatalogManager) notMatchedBySourceActions = newNotMatchedBySourceActions) } - // Skip the having clause here, this will be handled in ResolveAggregateFunctions. - case h: UnresolvedHaving => h + // UnresolvedHaving can host grouping expressions and aggregate functions. We should resolve + // columns with `agg.output` and the rule `ResolveAggregateFunctions` will push them down to + // Aggregate later. + case u @ UnresolvedHaving(cond, agg: Aggregate) if !cond.resolved => + u.mapExpressions { e => + // Columns in HAVING should be resolved with `agg.child.output` first, to follow the SQL + // standard. See more details in SPARK-31519. + val resolvedWithAgg = resolveColWithAgg(e, agg) + resolveExpressionByPlanChildren(resolvedWithAgg, u, allowOuter = true) + } + + // RepartitionByExpression can host missing attributes that are from a descendant node. + // For example, `spark.table("t").select($"a").repartition($"b")`. We can resolve `b` with + // table `t` even if there is a Project node between the table scan node and Sort node. + // We also need to propagate the missing attributes from the descendant node to the current + // node, and project them way at the end via an extra Project. + case r @ RepartitionByExpression(partitionExprs, child, _) + if !r.resolved || r.missingInput.nonEmpty => + val resolvedNoOuter = partitionExprs.map(resolveExpressionByPlanChildren(_, r)) + val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(resolvedNoOuter, child) + // Outer reference has lower priority than this. See the doc of `ResolveReferences`. 
+ val finalPartitionExprs = newPartitionExprs.map(resolveOuterRef) + if (child.output == newChild.output) { + r.copy(finalPartitionExprs, newChild) + } else { + Project(child.output, r.copy(finalPartitionExprs, newChild)) + } + + // Filter can host both grouping expressions/aggregate functions and missing attributes. + // The grouping expressions/aggregate functions resolution takes precedence over missing + // attributes. See the classdoc of `ResolveReferences` for details. + case f @ Filter(cond, child) if !cond.resolved || f.missingInput.nonEmpty => + val resolvedNoOuter = resolveExpressionByPlanChildren(cond, f) + val resolvedWithAgg = resolveColWithAgg(resolvedNoOuter, child) + val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), child) + // Outer reference has lowermost priority. See the doc of `ResolveReferences`. + val finalCond = resolveOuterRef(newCond.head) + if (child.output == newChild.output) { + f.copy(condition = finalCond) + } else { + // Add missing attributes and then project them away. + val newFilter = Filter(finalCond, newChild) + Project(child.output, newFilter) + } + + // Same as Filter, Sort can host both grouping expressions/aggregate functions and missing + // attributes as well. + case s @ Sort(orders, _, child) if !s.resolved || s.missingInput.nonEmpty => + val resolvedNoOuter = orders.map(resolveExpressionByPlanOutput(_, child)) + val resolvedWithAgg = resolvedNoOuter.map(resolveColWithAgg(_, child)) + val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(resolvedWithAgg, child) + // Outer reference has lowermost priority. See the doc of `ResolveReferences`. + val ordering = newOrder.map(e => resolveOuterRef(e).asInstanceOf[SortOrder]) + if (child.output == newChild.output) { + s.copy(order = ordering) + } else { + // Add missing attributes and then project them away. 
+ val newSort = s.copy(order = ordering, child = newChild) + Project(child.output, newSort) + } case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}") - q.mapExpressions(resolveExpressionByPlanChildren(_, q)) + q.mapExpressions(resolveExpressionByPlanChildren(_, q, allowOuter = true)) + } + + /** + * This method tries to resolve expressions and find missing attributes recursively. + * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes + * or resolved attributes which are missing from child output. This method tries to find the + * missing attributes and add them into the projection. + */ + private def resolveExprsAndAddMissingAttrs( + exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. + if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { + (exprs, plan) + } else { + plan match { + case p: Project => + // Resolving expressions against current plan. + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, p)) + // Recursively resolving expressions on the child of current plan. + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. 
+ val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) + (newExprs, Project(p.projectList ++ missingAttrs, newChild)) + + case a @ Aggregate(groupExprs, aggExprs, child) => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, a)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) + if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { + // All the missing attributes are grouping expressions, valid case. + (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) + } else { + // Need to add non-grouping attributes, invalid case. + (exprs, a) + } + + case g: Generate => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, g)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) + (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) + + // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes + // via its children. + case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, u)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) + (newExprs, u.withNewChildren(Seq(newChild))) + + // For other operators, we can't recursively resolve and add attributes via its children. + case other => + (exprs.map(resolveExpressionByPlanOutput(_, other)), other) + } + } } private object MergeResolvePolicy extends Enumeration { @@ -1687,7 +1834,7 @@ class Analyzer(override val catalogManager: CatalogManager) // Only Project and Aggregate can host star expressions. 
case u @ (_: Project | _: Aggregate) => Try(s.expand(u.children.head, resolver)) match { - case Success(expanded) => expanded.map(wrapOuterReference(_)) + case Success(expanded) => expanded.map(wrapOuterReference) case Failure(_) => throw e } // Do not use the outer plan to resolve the star expression @@ -1776,141 +1923,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * The first phase to resolve lateral column alias. See comments in - * [[ResolveLateralColumnAliasReference]] for more detailed explanation. - */ - object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { - import ResolveLateralColumnAliasReference.AliasEntry - - private def insertIntoAliasMap( - a: Alias, - idx: Int, - aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - - /** - * Use the given lateral alias to resolve the unresolved attribute with the name parts. - * - * Construct a dummy plan with the given lateral alias as project list, use the output of the - * plan to resolve. - * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. - */ - private def resolveByLateralAlias( - nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { - val resolvedAttr = resolveExpressionByPlanOutput( - expr = UnresolvedAttribute(nameParts), - plan = LocalRelation(Seq(lateralAlias.toAttribute)), - throws = false - ).asInstanceOf[NamedExpression] - if (resolvedAttr.resolved) { - Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) - } else { - None - } - } - - /** - * Recognize all the attributes in the given expression that reference lateral column aliases - * by looking up the alias map. Resolve these attributes and replace by wrapping with - * [[LateralColumnAliasReference]]. 
- * - * @param currentPlan Because lateral alias has lower resolution priority than table columns, - * the current plan is needed to first try resolving the attribute by its - * children - */ - private def wrapLCARef( - e: NamedExpression, - currentPlan: LogicalPlan, - aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n) - case n if n == 1 && aliases.head.alias.resolved => - // Only resolved alias can be the lateral column alias - // The lateral alias can be a struct and have nested field, need to construct - // a dummy plan to resolve the expression - resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) - case _ => u - } - case o: OuterReference - if aliasMap.contains( - o.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) - .map(_.head) - .getOrElse(o.name)) => - // handle OuterReference exactly same as UnresolvedAttribute - val nameParts = o - .getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) - .getOrElse(Seq(o.name)) - val aliases = aliasMap.get(nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAliasError(nameParts, n) - case n if n == 1 && aliases.head.alias.resolved => - resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) - case _ => o - } - }.asInstanceOf[NamedExpression] - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { - plan - } else { - plan.resolveOperatorsUpWithPruning( - 
_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { - case p @ Project(projectList, _) if p.childrenResolved - && !ResolveReferences.containsStar(projectList) - && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - val newProjectList = projectList.zipWithIndex.map { - case (a: Alias, idx) => - val lcaWrapped = wrapLCARef(a, p, aliasMap).asInstanceOf[Alias] - // Insert the LCA-resolved alias instead of the unresolved one into map. If it is - // resolved, it can be referenced as LCA by later expressions (chaining). - // Unresolved Alias is also added to the map to perform ambiguous name check, but - // only resolved alias can be LCA. - aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) - lcaWrapped - case (e, _) => - wrapLCARef(e, p, aliasMap) - } - p.copy(projectList = newProjectList) - - // Implementation notes: - // In Aggregate, introducing and wrapping this resolved leaf expression - // LateralColumnAliasReference is especially needed because it needs an accurate condition - // to trigger adding a Project above and extracting and pushing down aggregate functions - // or grouping expressions. Such operation can only be done once. With this - // LateralColumnAliasReference, that condition can simply be when the whole Aggregate is - // resolved. Otherwise, it can't tell if all aggregate functions are created and - // resolved so that it can start the extraction, because the lateral alias reference is - // unresolved and can be the argument to functions, blocking the resolution of functions. 
- case agg @ Aggregate(_, aggExprs, _) if agg.childrenResolved - && !ResolveReferences.containsStar(aggExprs) - && aggExprs.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => - - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - val newAggExprs = aggExprs.zipWithIndex.map { - case (a: Alias, idx) => - val lcaWrapped = wrapLCARef(a, agg, aliasMap).asInstanceOf[Alias] - aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) - lcaWrapped - case (e, _) => - wrapLCARef(e, agg, aliasMap) - } - agg.copy(aggregateExpressions = newAggExprs) - } - } - } - } - private def containsDeserializer(exprs: Seq[Expression]): Boolean = { exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } @@ -1953,10 +1965,11 @@ class Analyzer(override val catalogManager: CatalogManager) expr: Expression, resolveColumnByName: Seq[String] => Option[Expression], getAttrCandidates: () => Seq[Attribute], - throws: Boolean): Expression = { - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + throws: Boolean, + allowOuter: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) { if (e.resolved) return e - e match { + val resolved = e match { case f: LambdaFunction if !f.bound => f case GetColumnByOrdinal(ordinal, _) => @@ -1988,22 +2001,33 @@ class Analyzer(override val catalogManager: CatalogManager) logDebug(s"Resolving $u to $result") result + // Re-resolves `TempResolvedColumn` if it has tried to be resolved with Aggregate + // but failed. If we still can't resolve it, we should keep it as `TempResolvedColumn`, + // so that it won't become a fresh `TempResolvedColumn` again. 
+ case t: TempResolvedColumn if t.hasTried => withPosition(t) { + innerResolve(UnresolvedAttribute(t.nameParts), isTopLevel) match { + case _: UnresolvedAttribute => t + case other => other + } + } + case u @ UnresolvedExtractValue(child, fieldName) => val newChild = innerResolve(child, isTopLevel = false) if (newChild.resolved) { - withOrigin(u.origin) { - ExtractValue(newChild, fieldName, resolver) - } + ExtractValue(newChild, fieldName, resolver) } else { u.copy(child = newChild) } case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) } + resolved.copyTagsFrom(e) + resolved } try { - innerResolve(expr, isTopLevel = true) + val resolved = innerResolve(expr, isTopLevel = true) + if (allowOuter) resolveOuterRef(resolved) else resolved } catch { case ae: AnalysisException if !throws => logDebug(ae.getMessage) @@ -2011,6 +2035,117 @@ class Analyzer(override val catalogManager: CatalogManager) } } + // Resolves `UnresolvedAttribute` to `OuterReference`. + private def resolveOuterRef(e: Expression): Expression = { + val outerPlan = AnalysisContext.get.outerPlan + if (outerPlan.isEmpty) return e + + def resolve(nameParts: Seq[String]): Option[Expression] = try { + outerPlan.get match { + // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions. + // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will + // push them down to Aggregate later. This is similar to what we do in `resolveColumns`. 
+ case u @ UnresolvedHaving(_, agg: Aggregate) => + agg.resolveChildren(nameParts, resolver).orElse(u.resolveChildren(nameParts, resolver)) + .map(wrapOuterReference) + case other => + other.resolveChildren(nameParts, resolver).map(wrapOuterReference) + } + } catch { + case ae: AnalysisException => + logDebug(ae.getMessage) + None + } + + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { + case u: UnresolvedAttribute => + resolve(u.nameParts).getOrElse(u) + // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with + // Aggregate but failed. + case t: TempResolvedColumn if t.hasTried => + resolve(t.nameParts).getOrElse(t) + } + } + + // Resolves `UnresolvedAttribute` to `TempResolvedColumn` via `plan.child.output` if plan is an + // `Aggregate`. If `TempResolvedColumn` doesn't end up as aggregate function input or grouping + // column, we will undo the column resolution later to avoid confusing error message. E,g,, if + // a table `t` has columns `c1` and `c2`, for query `SELECT ... FROM t GROUP BY c1 HAVING c2 = 0`, + // even though we can resolve column `c2` here, we should undo it and fail with + // "Column c2 not found". 
+ private def resolveColWithAgg(e: Expression, plan: LogicalPlan): Expression = plan match { + case agg: Aggregate => + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute => + try { + agg.child.resolve(u.nameParts, resolver).map({ + case a: Alias => TempResolvedColumn(a.child, u.nameParts) + case o => TempResolvedColumn(o, u.nameParts) + }).getOrElse(u) + } catch { + case ae: AnalysisException => + logDebug(ae.getMessage) + u + } + } + case _ => e + } + + private def resolveLateralColumnAlias(selectList: Seq[Expression]): Seq[Expression] = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) return selectList + + // A mapping from lower-cased alias name to either the Alias itself, or the count of aliases + // that have the same lower-cased name. If the count is larger than 1, we won't use it to + // resolve lateral column aliases. + val aliasMap = mutable.HashMap.empty[String, Either[Alias, Int]] + + def resolve(e: Expression): Expression = { + e.transformWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, LATERAL_COLUMN_ALIAS_REFERENCE)) { + case u: UnresolvedAttribute => + // Lateral column alias does not have qualifiers. We always use the first name part to + // look up lateral column aliases. + val lowerCasedName = u.nameParts.head.toLowerCase(Locale.ROOT) + aliasMap.get(lowerCasedName).map { + case scala.util.Left(alias) => + if (alias.resolved) { + val resolvedAttr = resolveExpressionByPlanOutput( + u, LocalRelation(Seq(alias.toAttribute)), throws = true + ).asInstanceOf[NamedExpression] + assert(resolvedAttr.resolved) + LateralColumnAliasReference(resolvedAttr, u.nameParts, alias.toAttribute) + } else { + // Still returns a `LateralColumnAliasReference` even if the lateral column alias + // is not resolved yet. This is to make sure we won't mistakenly resolve it to + // outer references. 
+ LateralColumnAliasReference(u, u.nameParts, alias.toAttribute) + } + case scala.util.Right(count) => + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, count) + }.getOrElse(u) + + case LateralColumnAliasReference(u: UnresolvedAttribute, _, _) => + resolve(u) + } + } + + selectList.map { + case a: Alias => + val result = resolve(a) + val lowerCasedName = a.name.toLowerCase(Locale.ROOT) + aliasMap.get(lowerCasedName) match { + case Some(scala.util.Left(_)) => + aliasMap(lowerCasedName) = scala.util.Right(2) + case Some(scala.util.Right(count)) => + aliasMap(lowerCasedName) = scala.util.Right(count + 1) + case None => + aliasMap += lowerCasedName -> scala.util.Left(a) + } + result + case other => resolve(other) + } + } + /** * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the * input plan's output attributes. In order to resolve the nested fields correctly, this function @@ -2026,14 +2161,16 @@ class Analyzer(override val catalogManager: CatalogManager) def resolveExpressionByPlanOutput( expr: Expression, plan: LogicalPlan, - throws: Boolean = false): Expression = { + throws: Boolean = false, + allowOuter: Boolean = false): Expression = { resolveExpression( expr, resolveColumnByName = nameParts => { plan.resolve(nameParts, resolver) }, getAttrCandidates = () => plan.output, - throws = throws) + throws = throws, + allowOuter = allowOuter) } /** @@ -2046,7 +2183,8 @@ class Analyzer(override val catalogManager: CatalogManager) */ def resolveExpressionByPlanChildren( e: Expression, - q: LogicalPlan): Expression = { + q: LogicalPlan, + allowOuter: Boolean = false): Expression = { resolveExpression( e, resolveColumnByName = nameParts => { @@ -2056,7 +2194,8 @@ class Analyzer(override val catalogManager: CatalogManager) assert(q.children.length == 1) q.children.head.output }, - throws = true) + throws = true, + allowOuter = allowOuter) } /** @@ -2158,152 +2297,6 @@ class Analyzer(override val catalogManager: 
CatalogManager) } } - /** - * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT - * clause. This rule detects such queries and adds the required attributes to the original - * projection, so that they will be available during sorting. Another projection is added to - * remove these attributes after sorting. - * - * The HAVING clause could also used a grouping columns that is not presented in the SELECT. - */ - object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsAnyPattern(SORT, FILTER, REPARTITION_OPERATION), ruleId) { - // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions - case sa @ Sort(_, _, child: Aggregate) => sa - - case s @ Sort(order, _, child) - if (!s.resolved || s.missingInput.nonEmpty) && child.resolved => - val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) - val ordering = newOrder.map(_.asInstanceOf[SortOrder]) - if (child.output == newChild.output) { - s.copy(order = ordering) - } else { - // Add missing attributes and then project them away. - val newSort = s.copy(order = ordering, child = newChild) - Project(child.output, newSort) - } - - case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved => - val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) - if (child.output == newChild.output) { - f.copy(condition = newCond.head) - } else { - // Add missing attributes and then project them away. 
- val newFilter = Filter(newCond.head, newChild) - Project(child.output, newFilter) - } - - case r @ RepartitionByExpression(partitionExprs, child, _) - if (!r.resolved || r.missingInput.nonEmpty) && child.resolved => - val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(partitionExprs, child) - if (child.output == newChild.output) { - r.copy(newPartitionExprs, newChild) - } else { - Project(child.output, r.copy(newPartitionExprs, newChild)) - } - } - - /** - * This method tries to resolve expressions and find missing attributes recursively. - * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes - * or resolved attributes which are missing from child output. This method tries to find the - * missing attributes and add them into the projection. - */ - private def resolveExprsAndAddMissingAttrs( - exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - // Missing attributes can be unresolved attributes or resolved attributes which are not in - // the output attributes of the plan. - if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { - (exprs, plan) - } else { - plan match { - case p: Project => - // Resolving expressions against current plan. - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, p)) - // Recursively resolving expressions on the child of current plan. - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - // If some attributes used by expressions are resolvable only on the rewritten child - // plan, we need to add them into original projection. 
- val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) - (newExprs, Project(p.projectList ++ missingAttrs, newChild)) - - case a @ Aggregate(groupExprs, aggExprs, child) => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, a)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) - if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { - // All the missing attributes are grouping expressions, valid case. - (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) - } else { - // Need to add non-grouping attributes, invalid case. - (exprs, a) - } - - case g: Generate => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, g)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) - (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) - - // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes - // via its children. - case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, u)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) - (newExprs, u.withNewChildren(Seq(newChild))) - - // For other operators, we can't recursively resolve and add attributes via its children. - case other => - (exprs.map(resolveExpressionByPlanOutput(_, other)), other) - } - } - } - } - - /** - * Resolves `UnresolvedAttribute` to `OuterReference` if we are resolving subquery plans (when - * `AnalysisContext.get.outerPlan` is set). - */ - object ResolveOuterReferences extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - // Only apply this rule if we are resolving subquery plans. 
- if (AnalysisContext.get.outerPlan.isEmpty) return plan - - // We must run these 3 rules first, as they also resolve `UnresolvedAttribute` and have - // higher priority than outer reference resolution. - val prepared = ResolveAggregateFunctions(ResolveMissingReferences(ResolveReferences(plan))) - prepared.resolveOperatorsDownWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { - // Handle `Generate` specially here, because `Generate.generatorOutput` starts with - // `UnresolvedAttribute` but we should never resolve it to outer references. It's a bit - // hacky that `Generate` uses `UnresolvedAttribute` to store the generator column names, - // we should clean it up later. - case g: Generate if g.childrenResolved && !g.resolved => - val newGenerator = g.generator.transformWithPruning( - _.containsPattern(UNRESOLVED_ATTRIBUTE))(resolveOuterReference) - val resolved = g.copy(generator = newGenerator.asInstanceOf[Generator]) - resolved.copyTagsFrom(g) - resolved - case q: LogicalPlan if q.childrenResolved && !q.resolved => - q.transformExpressionsWithPruning( - _.containsPattern(UNRESOLVED_ATTRIBUTE))(resolveOuterReference) - } - } - - private val resolveOuterReference: PartialFunction[Expression, Expression] = { - case u @ UnresolvedAttribute(nameParts) => withPosition(u) { - try { - AnalysisContext.get.outerPlan.get.resolveChildren(nameParts, resolver) match { - case Some(resolved) => wrapOuterReference(resolved, Some(nameParts)) - case None => u - } - } catch { - case ae: AnalysisException => - logDebug(ae.getMessage) - u - } - } - } - } /** * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the @@ -2819,34 +2812,39 @@ class Analyzer(override val catalogManager: CatalogManager) * This rule finds aggregate expressions that are not in an aggregate operator. For example, * those in a HAVING clause or ORDER BY clause. 
These expressions are pushed down to the * underlying aggregate operator and then projected away after the original operator. + * + * We need to make sure the expressions are all fully resolved before looking for aggregate functions + * and group by expressions from them. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsPattern(AGGREGATE), ruleId) { - // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly - // resolve the having condition expression, here we skip resolving it in ResolveReferences - // and transform it to Filter after aggregate is resolved. Basically columns in HAVING should - // be resolved with `agg.child.output` first. See more details in SPARK-31519. - case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved => + case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved && cond.resolved => resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { - Filter(newExprs.head, newChild) + val newCond = newExprs.head + if (newCond.resolved) { + Filter(newCond, newChild) + } else { + // The condition can be unresolved after the resolution, as we may mark + // `TempResolvedColumn` as unresolved if it's not aggregate function inputs or grouping + // expressions. We should keep `UnresolvedHaving` as the rule `ResolveReferences` can + // re-resolve `TempResolvedColumn` and `UnresolvedHaving` has a special column + // resolution order. + UnresolvedHaving(newCond, newChild) + } }) - case Filter(cond, agg: Aggregate) if agg.resolved => - // We should resolve the references normally based on child (agg.output) first. 
- val maybeResolved = resolveExpressionByPlanOutput(cond, agg) - resolveOperatorWithAggregate(Seq(maybeResolved), agg, (newExprs, newChild) => { + case Filter(cond, agg: Aggregate) if agg.resolved && cond.resolved => + resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { Filter(newExprs.head, newChild) }) - case Sort(sortOrder, global, agg: Aggregate) if agg.resolved => - // We should resolve the references normally based on child (agg.output) first. - val maybeResolved = sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg)) - resolveOperatorWithAggregate(maybeResolved, agg, (newExprs, newChild) => { - val newSortOrder = sortOrder.zip(newExprs).map { + case s @ Sort(_, _, agg: Aggregate) if agg.resolved && s.order.forall(_.resolved) => + resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => { + val newSortOrder = s.order.zip(newExprs).map { case (sortOrder, expr) => sortOrder.copy(child = expr) } - Sort(newSortOrder, global, newChild) + s.copy(order = newSortOrder, child = newChild) }) } @@ -2859,45 +2857,12 @@ class Analyzer(override val catalogManager: CatalogManager) def resolveExprsWithAggregate( exprs: Seq[Expression], agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = { - def resolveCol(input: Expression): Expression = { - input.transform { - case u: UnresolvedAttribute => - try { - // Resolve the column and wrap it with `TempResolvedColumn`. If the resolved column - // doesn't end up with as aggregate function input or grouping column, we should - // undo the column resolution to avoid confusing error message. For example, if - // a table `t` has two columns `c1` and `c2`, for query `SELECT ... FROM t - // GROUP BY c1 HAVING c2 = 0`, even though we can resolve column `c2` here, we - // should undo it later and fail with "Column c2 not found". 
- agg.child.resolve(u.nameParts, resolver).map({ - case a: Alias => TempResolvedColumn(a.child, u.nameParts) - case o => TempResolvedColumn(o, u.nameParts) - }).getOrElse(u) - } catch { - case ae: AnalysisException => - logDebug(ae.getMessage) - u - } - } - } - - def resolveSubQuery(input: Expression): Expression = { - if (SubqueryExpression.hasSubquery(input)) { - val fake = Project(Alias(input, "fake")() :: Nil, agg.child) - ResolveSubquery(fake).asInstanceOf[Project].projectList.head.asInstanceOf[Alias].child - } else { - input - } - } - val extraAggExprs = ArrayBuffer.empty[NamedExpression] val transformed = exprs.map { e => - // Try resolving the expression as though it is in the aggregate clause. - val maybeResolved = resolveSubQuery(resolveCol(e)) - if (!maybeResolved.resolved) { - maybeResolved + if (!e.resolved) { + e } else { - buildAggExprList(maybeResolved, agg, extraAggExprs) + buildAggExprList(e, agg, extraAggExprs) } } (extraAggExprs.toSeq, transformed) @@ -2918,12 +2883,12 @@ class Analyzer(override val catalogManager: CatalogManager) } else { expr match { case ae: AggregateExpression => - val cleaned = RemoveTempResolvedColumn.trimTempResolvedColumn(ae) + val cleaned = trimTempResolvedColumn(ae) val alias = Alias(cleaned, cleaned.toString)() aggExprList += alias alias.toAttribute case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) => - RemoveTempResolvedColumn.trimTempResolvedColumn(grouping) match { + trimTempResolvedColumn(grouping) match { case ne: NamedExpression => aggExprList += ne ne.toAttribute @@ -2933,15 +2898,35 @@ class Analyzer(override val catalogManager: CatalogManager) alias.toAttribute } case t: TempResolvedColumn => - // Undo the resolution as this column is neither inside aggregate functions nor a - // grouping column. It shouldn't be resolved with `agg.child.output`. 
- RemoveTempResolvedColumn.restoreTempResolvedColumn(t) + if (t.child.isInstanceOf[Attribute]) { + // This column is neither inside aggregate functions nor a grouping column. It + // shouldn't be resolved with `agg.child.output`. Mark it as "hasTried", so that it + // can be re-resolved later or go back to `UnresolvedAttribute` at the end. + withOrigin(t.origin)(t.copy(hasTried = true)) + } else { + // This is a nested column, we still have a chance to match grouping expressions with + // the top-level column. Here we wrap the underlying `Attribute` with + // `TempResolvedColumn` and try again. + val childWithTempCol = t.child.transformUp { + case a: Attribute => TempResolvedColumn(a, Seq(a.name)) + } + val newChild = buildAggExprList(childWithTempCol, agg, aggExprList) + if (newChild.containsPattern(TEMP_RESOLVED_COLUMN)) { + withOrigin(t.origin)(t.copy(hasTried = true)) + } else { + newChild + } + } case other => other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList))) } } } + private def trimTempResolvedColumn(input: Expression): Expression = input.transform { + case t: TempResolvedColumn => t.child + } + def resolveOperatorWithAggregate( exprs: Seq[Expression], agg: Aggregate, @@ -4255,42 +4240,32 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } /** - * The rule `ResolveAggregationFunctions` in the main resolution batch creates - * [[TempResolvedColumn]] in filter conditions and sort expressions to hold the temporarily resolved - * column with `agg.child`. When filter conditions or sort expressions are resolved, - * `ResolveAggregationFunctions` will replace [[TempResolvedColumn]], to [[AttributeReference]] if - * it's inside aggregate functions or group expressions, or to [[UnresolvedAttribute]] otherwise, - * hoping other rules can resolve it. 
+ * The rule `ResolveReferences` in the main resolution batch creates [[TempResolvedColumn]] in + * UnresolvedHaving/Filter/Sort to hold the temporarily resolved column with `agg.child`. + * + * If the expression hosting [[TempResolvedColumn]] is fully resolved, the rule + * `ResolveAggregationFunctions` will + * - Replace [[TempResolvedColumn]] with [[AttributeReference]] if it's inside aggregate functions + * or grouping expressions. + * - Mark [[TempResolvedColumn]] as `hasTried` if not inside aggregate functions or grouping + * expressions, hoping other rules can re-resolve it. + * `ResolveReferences` will re-resolve [[TempResolvedColumn]] if `hasTried` is true, and keep it + * unchanged if the resolution fails. We should turn it back to [[UnresolvedAttribute]] so that the + * analyzer can report missing column error later. * - * This rule runs after the main resolution batch, and can still hit [[TempResolvedColumn]] if - * filter conditions or sort expressions are not resolved. When this happens, there is no point to - * turn [[TempResolvedColumn]] to [[UnresolvedAttribute]], as we can't resolve the column - * differently, and query will fail. This rule strips all [[TempResolvedColumn]]s in Filter/Sort and - * turns them to [[AttributeReference]] so that the error message can tell users why the filter - * conditions or sort expressions were not resolved. + * If the expression hosting [[TempResolvedColumn]] is not resolved, [[TempResolvedColumn]] will + * remain with `hasTried` as false. We should strip [[TempResolvedColumn]], so that users can see + * the reason why the expression is not resolved, e.g. type mismatch. 
*/ object RemoveTempResolvedColumn extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsUp { - case f @ Filter(cond, agg: Aggregate) if agg.resolved => - withOrigin(f.origin)(f.copy(condition = trimTempResolvedColumn(cond))) - case s @ Sort(sortOrder, _, agg: Aggregate) if agg.resolved => - val newSortOrder = sortOrder.map { order => - trimTempResolvedColumn(order).asInstanceOf[SortOrder] + plan.resolveExpressionsWithPruning(_.containsPattern(TEMP_RESOLVED_COLUMN)) { + case t: TempResolvedColumn => + if (t.hasTried) { + UnresolvedAttribute(t.nameParts) + } else { + t.child } - withOrigin(s.origin)(s.copy(order = newSortOrder)) - case other => other.transformExpressionsUp { - // This should not happen. We restore TempResolvedColumn to UnresolvedAttribute to be safe. - case t: TempResolvedColumn => restoreTempResolvedColumn(t) - } } } - - def trimTempResolvedColumn(input: Expression): Expression = input.transform { - case t: TempResolvedColumn => t.child - } - - def restoreTempResolvedColumn(t: TempResolvedColumn): Expression = { - CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts)) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 2fad1faec3f52..089d2bec2172d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import 
org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -82,26 +81,10 @@ import org.apache.spark.sql.internal.SQLConf * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27, * dept#14] * +- Child [dept#14,name#15,salary#16,bonus#17] - * - * - * The name resolution priority: - * local table column > local lateral column alias > outer reference - * - * Because lateral column alias has higher resolution priority than outer reference, it will try - * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an - * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with - * [[LateralColumnAliasReference]]. */ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { case class AliasEntry(alias: Alias, index: Int) - /** - * A tag to store the nameParts from the original unresolved attribute. - * It is set for [[OuterReference]], used in the current rule to convert [[OuterReference]] back - * to [[LateralColumnAliasReference]]. - */ - val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") - private def assignAlias(expr: Expression): NamedExpression = { expr match { case ne: NamedExpression => ne @@ -112,6 +95,11 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan + } else if (plan.containsPattern(TEMP_RESOLVED_COLUMN)) { + // We should not change the plan if `TempResolvedColumn` is present in the query plan. 
It + // needs certain plan shape to get resolved, such as Filter/Sort + Aggregate. LCA resolution + // may break the plan shape, like adding Project above Aggregate. + plan } else { // phase 2: unwrap plan.resolveOperatorsUpWithPruning( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 687bf4f775edd..e194aa375314e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -661,13 +661,26 @@ case object UnresolvedSeed extends LeafExpression with Unevaluable { /** * An intermediate expression to hold a resolved (nested) column. Some rules may need to undo the - * column resolution and use this expression to keep the original column name. + * column resolution and use this expression to keep the original column name, or redo the column + * resolution with a different priority if the analyzer has tried to resolve it with the default + * priority before but failed (i.e. `hasTried` is true). */ -case class TempResolvedColumn(child: Expression, nameParts: Seq[String]) extends UnaryExpression +case class TempResolvedColumn( + child: Expression, + nameParts: Seq[String], + hasTried: Boolean = false) extends UnaryExpression with Unevaluable { + // If it has been tried to be resolved but failed, mark it as unresolved so that other rules can + // try to resolve it again. + override lazy val resolved = child.resolved && !hasTried override lazy val canonicalized = child.canonicalized override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable + // `TempResolvedColumn` is logically a leaf node. We should not count it as a missing reference + // when resolving Filter/Sort/RepartitionByExpression. 
However, we should not make it a real + // leaf node, as rules that update expr IDs should update `TempResolvedColumn.child` as well. + override def references: AttributeSet = AttributeSet.empty override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) - override def sql: String = child.sql + final override val nodePatterns: Seq[TreePattern] = Seq(TEMP_RESOLVED_COLUMN) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9670b7997c1a0..a2d167a77f3dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -433,26 +433,27 @@ case class OuterReference(e: NamedExpression) /** * A placeholder used to hold a [[NamedExpression]] that has been temporarily resolved as the - * reference to a lateral column alias. + * reference to a lateral column alias. It will be restored back to [[UnresolvedAttribute]] if + * the lateral column alias can't be resolved, or become a normal resolved column in the rewritten + * plan after lateral column resolution. There should be no [[LateralColumnAliasReference]] beyond + * analyzer: if the plan passes all analysis check, then all [[LateralColumnAliasReference]] should + * already be removed. * - * This is created and removed by Analyzer rule [[ResolveLateralColumnAlias]]. - * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all - * analysis check, then all [[LateralColumnAliasReference]] should already be removed. - * - * @param ne the resolved [[NamedExpression]] by lateral column alias - * @param nameParts the named parts of the original [[UnresolvedAttribute]]. Used to restore back + * @param ne the [[NamedExpression]] produced by column resolution. 
Can be [[UnresolvedAttribute]] + * if the referenced lateral column alias is not resolved yet. + * @param nameParts the name parts of the original [[UnresolvedAttribute]]. Used to restore back * to [[UnresolvedAttribute]] when needed * @param a the attribute of referenced lateral column alias. Used to match alias when unwrapping - * and resolving LateralColumnAliasReference + * and resolving lateral column aliases and rewriting the query plan. */ case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String], a: Attribute) extends LeafExpression with NamedExpression with Unevaluable { - assert(ne.resolved) - override def name: String = - nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + assert(ne.resolved || ne.isInstanceOf[UnresolvedAttribute]) + override def name: String = ne.name override def exprId: ExprId = ne.exprId override def qualifier: Seq[String] = ne.qualifier override def toAttribute: Attribute = ne.toAttribute + override lazy val resolved = ne.resolved override def newInstance(): NamedExpression = LateralColumnAliasReference(ne.newInstance(), nameParts, a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index b510893f370e5..e7384dac2d53e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan} @@ -159,12 +158,8 @@ 
object SubExprUtils extends PredicateHelper { /** * Wrap attributes in the expression with [[OuterReference]]s. */ - def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = { - e.transform { case a: Attribute => - val o = OuterReference(a) - nameParts.map(o.setTagValue(NAME_PARTS_FROM_UNRESOLVED_ATTR, _)) - o - }.asInstanceOf[E] + def wrapOuterReference[E <: Expression](e: E): E = { + e.transform { case a: Attribute => OuterReference(a) }.asInstanceOf[E] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 4cf774b036277..4be3f97dca8fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -58,7 +58,6 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics" :: "org.apache.spark.sql.catalyst.analysis.ResolveHigherOrderFunctions" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveInsertInto" :: - "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveMissingReferences" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNewInstance" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveOrdinalInOrderByAndGroupBy" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 9eb8ce21ef244..5dc60acea3c46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -122,6 +122,7 @@ object TreePattern extends Enumeration { val UNION: Value = Value val 
UNRESOLVED_RELATION: Value = Value val UNRESOLVED_WITH: Value = Value + val TEMP_RESOLVED_COLUMN: Value = Value val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 89898b89a38e4..14abd4dbed0a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -547,8 +547,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { test("Lateral alias of a complex type") { // test both Project and Aggregate - // TODO(anchovyu): re-enable aggregate tests when fixed the having issue - val querySuffixes = Seq(""/* , s"FROM $testTable GROUP BY dept HAVING dept = 6" */) + val querySuffixes = Seq("", s"FROM $testTable GROUP BY dept HAVING dept = 6") querySuffixes.foreach { querySuffix => checkAnswer( sql(s"SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1 $querySuffix"),