Skip to content

Commit

Permalink
[SPARK-41405][SQL] Centralize the column resolution logic
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR is a major refactor of how Spark resolves columns. Today, the column resolution logic is spread across several rules, which makes it hard to understand. Maintaining the resolution precedence is also fragile, because you have to carefully deal with the interactions between these rules.

This PR centralizes the column resolution logic into a single rule: the existing `ResolveReferences` rule, so that we no longer need to worry about the interactions between multiple rules. The detailed resolution precedence is also documented.

### Why are the changes needed?

code cleanup

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

existing tests

Closes #38888 from cloud-fan/col.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
cloud-fan committed Jan 3, 2023
1 parent f0d9692 commit 3c40be2
Show file tree
Hide file tree
Showing 8 changed files with 424 additions and 453 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -82,26 +81,10 @@ import org.apache.spark.sql.internal.SQLConf
* +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,
* dept#14]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
*
* The name resolution priority:
* local table column > local lateral column alias > outer reference
*
* Because lateral column alias has higher resolution priority than outer reference, it will try
* to resolve an [[OuterReference]] using lateral column alias in phase 1, similar to an
* [[UnresolvedAttribute]]. If it succeeds, it strips [[OuterReference]] and also wraps it with
* [[LateralColumnAliasReference]].
*/
object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
case class AliasEntry(alias: Alias, index: Int)

/**
* A tag to store the nameParts from the original unresolved attribute.
* It is set for [[OuterReference]], used in the current rule to convert [[OuterReference]] back
* to [[LateralColumnAliasReference]].
*/
val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr")

private def assignAlias(expr: Expression): NamedExpression = {
expr match {
case ne: NamedExpression => ne
Expand All @@ -112,6 +95,11 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
plan
} else if (plan.containsPattern(TEMP_RESOLVED_COLUMN)) {
// We should not change the plan if `TempResolvedColumn` is present in the query plan. It
// needs a certain plan shape to get resolved, such as Filter/Sort + Aggregate. LCA resolution
// may break the plan shape, e.g. by adding a Project above the Aggregate.
plan
} else {
// phase 2: unwrap
plan.resolveOperatorsUpWithPruning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,13 +661,26 @@ case object UnresolvedSeed extends LeafExpression with Unevaluable {

/**
* An intermediate expression to hold a resolved (nested) column. Some rules may need to undo the
* column resolution and use this expression to keep the original column name.
* column resolution and use this expression to keep the original column name, or redo the column
* resolution with a different priority if the analyzer has tried to resolve it with the default
* priority before but failed (i.e. `hasTried` is true).
*/
case class TempResolvedColumn(child: Expression, nameParts: Seq[String]) extends UnaryExpression
case class TempResolvedColumn(
child: Expression,
nameParts: Seq[String],
hasTried: Boolean = false) extends UnaryExpression
with Unevaluable {
// If resolution has already been attempted and failed, mark it as unresolved so that other rules
// can try to resolve it again.
override lazy val resolved = child.resolved && !hasTried
override lazy val canonicalized = child.canonicalized
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
// `TempResolvedColumn` is logically a leaf node. We should not count it as a missing reference
// when resolving Filter/Sort/RepartitionByExpression. However, we should not make it a real
// leaf node, as rules that update expr IDs should update `TempResolvedColumn.child` as well.
override def references: AttributeSet = AttributeSet.empty
override protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)
override def sql: String = child.sql
final override val nodePatterns: Seq[TreePattern] = Seq(TEMP_RESOLVED_COLUMN)
}
Original file line number Diff line number Diff line change
Expand Up @@ -433,26 +433,27 @@ case class OuterReference(e: NamedExpression)

/**
* A placeholder used to hold a [[NamedExpression]] that has been temporarily resolved as the
* reference to a lateral column alias.
* reference to a lateral column alias. It will be restored to [[UnresolvedAttribute]] if
* the lateral column alias can't be resolved, or become a normal resolved column in the rewritten
* plan after lateral column resolution. There should be no [[LateralColumnAliasReference]] beyond
* the analyzer: if the plan passes all analysis checks, then all [[LateralColumnAliasReference]]s
* should already be removed.
*
* This is created and removed by Analyzer rule [[ResolveLateralColumnAlias]].
* There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all
* analysis check, then all [[LateralColumnAliasReference]] should already be removed.
*
* @param ne the resolved [[NamedExpression]] by lateral column alias
* @param nameParts the named parts of the original [[UnresolvedAttribute]]. Used to restore back
* @param ne the [[NamedExpression]] produced by column resolution. Can be [[UnresolvedAttribute]]
* if the referenced lateral column alias is not resolved yet.
* @param nameParts the name parts of the original [[UnresolvedAttribute]]. Used to restore back
* to [[UnresolvedAttribute]] when needed
* @param a the attribute of referenced lateral column alias. Used to match alias when unwrapping
* and resolving LateralColumnAliasReference
* and resolving lateral column aliases and rewriting the query plan.
*/
case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String], a: Attribute)
extends LeafExpression with NamedExpression with Unevaluable {
assert(ne.resolved)
override def name: String =
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
assert(ne.resolved || ne.isInstanceOf[UnresolvedAttribute])
override def name: String = ne.name
override def exprId: ExprId = ne.exprId
override def qualifier: Seq[String] = ne.qualifier
override def toAttribute: Attribute = ne.toAttribute
override lazy val resolved = ne.resolved
override def newInstance(): NamedExpression =
LateralColumnAliasReference(ne.newInstance(), nameParts, a)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan}
Expand Down Expand Up @@ -159,12 +158,8 @@ object SubExprUtils extends PredicateHelper {
/**
* Wrap attributes in the expression with [[OuterReference]]s.
*/
def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = {
e.transform { case a: Attribute =>
val o = OuterReference(a)
nameParts.map(o.setTagValue(NAME_PARTS_FROM_UNRESOLVED_ATTR, _))
o
}.asInstanceOf[E]
def wrapOuterReference[E <: Expression](e: E): E = {
e.transform { case a: Attribute => OuterReference(a) }.asInstanceOf[E]
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics" ::
"org.apache.spark.sql.catalyst.analysis.ResolveHigherOrderFunctions" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveInsertInto" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveMissingReferences" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNewInstance" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveOrdinalInOrderByAndGroupBy" ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ object TreePattern extends Enumeration {
val UNION: Value = Value
val UNRESOLVED_RELATION: Value = Value
val UNRESOLVED_WITH: Value = Value
val TEMP_RESOLVED_COLUMN: Value = Value
val TYPED_FILTER: Value = Value
val WINDOW: Value = Value
val WITH_WINDOW_DEFINITION: Value = Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,8 +547,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase {

test("Lateral alias of a complex type") {
// test both Project and Aggregate
// TODO(anchovyu): re-enable aggregate tests when fixed the having issue
val querySuffixes = Seq(""/* , s"FROM $testTable GROUP BY dept HAVING dept = 6" */)
val querySuffixes = Seq("", s"FROM $testTable GROUP BY dept HAVING dept = 6")
querySuffixes.foreach { querySuffix =>
checkAnswer(
sql(s"SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1 $querySuffix"),
Expand Down

0 comments on commit 3c40be2

Please sign in to comment.