[SPARK-27561][SQL] Support implicit lateral column alias resolution on Project

### What changes were proposed in this pull request?
This PR implements a new feature: implicit lateral column alias resolution for the `Project` case. It is temporarily controlled by `spark.sql.lateralColumnAlias.enableImplicitResolution` (default `false` for now; the conf will be turned on once the feature is completely merged).

#### Lateral column alias
See https://issues.apache.org/jira/browse/SPARK-27561 for more details on lateral column alias.
There are two main cases to support: LCA in Project, and LCA in Aggregate.
```sql
-- LCA in Project. The base_salary references an attribute defined by a previous alias
SELECT salary AS base_salary, base_salary + bonus AS total_salary
FROM employee

-- LCA in Aggregate. The avg_salary references an attribute defined by a previous alias
SELECT dept, avg(salary) AS avg_salary, avg_salary + avg(bonus)
FROM employee
GROUP BY dept
```
This **implicit** lateral column alias (no explicit keyword, e.g. `lateral.base_salary`) should be supported.

#### High level design
This PR defines two new resolution rules, `WrapLateralColumnAliasReference` and `ResolveLateralColumnAliasReference`, to resolve implicit lateral column aliases, covering the `Project` case.
It introduces a new leaf node NamedExpression, `LateralColumnAliasReference`, as a placeholder that holds a reference which has been temporarily resolved as a reference to a lateral column alias.

The whole process is generally divided into two phases:
1) Recognize **resolved** lateral aliases and wrap the attributes referencing them with `LateralColumnAliasReference`.
2) When the whole operator is resolved, unwrap `LateralColumnAliasReference`. For Project, this further resolves the attributes and pushes the referenced lateral aliases down into a new Project.

For example:
```
// Before
Project [age AS a, 'a + 1]
+- Child

// After phase 1
Project [age AS a, lateralalias(a) + 1]
+- Child

// After phase 2
Project [a, a + 1]
+- Project [child output, age AS a]
   +- Child
```
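
The two phases can be illustrated with a small, self-contained toy model. This is only a sketch under stated assumptions: `Expr`, `Plan`, `LcaRef`, `wrap`, and `unwrap` are hypothetical stand-ins invented for illustration, not the Catalyst classes touched by this PR, and the model ignores details such as case-insensitive names, nested struct fields, and ambiguity checks.

```scala
// Toy model only: every type here is a hypothetical stand-in, not a Catalyst class.
object LcaSketch {
  sealed trait Expr
  case class Col(name: String) extends Expr          // a resolved attribute
  case class Unresolved(name: String) extends Expr   // like 'a in the plans above
  case class LcaRef(name: String) extends Expr       // like lateralalias(a)
  case class Lit(value: Int) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Alias(child: Expr, name: String) extends Expr

  sealed trait Plan { def output: Seq[String] }
  case class Relation(output: Seq[String]) extends Plan
  case class Project(list: Seq[Expr], child: Plan) extends Plan {
    def output: Seq[String] = list.map {
      case Alias(_, n) => n
      case Col(n) => n
      case other => other.toString
    }
  }

  // Bottom-up rewrite: children first, then the node itself.
  def rewrite(e: Expr)(f: PartialFunction[Expr, Expr]): Expr = {
    val withNewChildren = e match {
      case Add(l, r) => Add(rewrite(l)(f), rewrite(r)(f))
      case Alias(c, n) => Alias(rewrite(c)(f), n)
      case leaf => leaf
    }
    f.applyOrElse(withNewChildren, (x: Expr) => x)
  }

  def containsLcaRef(e: Expr): Boolean = e match {
    case _: LcaRef => true
    case Add(l, r) => containsLcaRef(l) || containsLcaRef(r)
    case Alias(c, _) => containsLcaRef(c)
    case _ => false
  }

  // Phase 1: wrap names that the child cannot resolve but an earlier alias can.
  def wrap(p: Project): Project = {
    var seen = Set.empty[String]
    val newList = p.list.map { e =>
      val wrapped = rewrite(e) {
        case Unresolved(n) if !p.child.output.contains(n) && seen.contains(n) => LcaRef(n)
      }
      wrapped match {
        case Alias(_, n) => seen += n
        case _ => ()
      }
      wrapped
    }
    p.copy(list = newList)
  }

  // Phase 2: push referenced aliases into a new inner Project and unwrap the references.
  def unwrap(p: Project): Project = {
    val aliases: Map[String, Alias] = p.list.collect { case a: Alias => a.name -> a }.toMap
    val referenced = scala.collection.mutable.LinkedHashSet.empty[String]
    // Chained aliases (aliases that still contain an LcaRef) are left for a later round.
    def pushable(n: String): Boolean = aliases.get(n).exists(a => !containsLcaRef(a))
    val unwrapped = p.list.map { e =>
      rewrite(e) { case LcaRef(n) if pushable(n) => referenced += n; Col(n) }
    }
    if (referenced.isEmpty) p
    else {
      val outer = unwrapped.map {
        case Alias(_, n) if referenced(n) => Col(n)
        case other => other
      }
      val inner: Seq[Expr] = p.child.output.map(Col(_)) ++ referenced.toSeq.map(aliases)
      Project(outer, Project(inner, p.child))
    }
  }
}
```

In this toy model, `unwrap(wrap(p))` reproduces the shape of the "After phase 2" plan above. A chained reference (an alias that itself refers to an earlier alias) is left wrapped in the first round and resolved by calling `unwrap` again on the result.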

#### Resolution order
Given this new rule, the name resolution order will be (higher -> lower):
```
local table column > local metadata attribute > local lateral column alias > all others (outer reference of subquery, parameters of SQL UDF, ..)
```
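
As a rough illustration of this ordering, the lookup can be thought of as trying each scope in turn and stopping at the first hit. The sketch below is a simplified model, not the analyzer's implementation: `Scope`, `resolve`, and the result types are hypothetical names, and the "all others" bucket is reduced to outer references only.

```scala
// Toy model only: Scope, resolve, and the result types are hypothetical names.
object ResolutionOrderSketch {
  sealed trait Resolved
  case class TableColumn(name: String) extends Resolved
  case class MetadataAttribute(name: String) extends Resolved
  case class LateralColumnAlias(name: String) extends Resolved
  case class OuterReference(name: String) extends Resolved

  // Names visible at each priority level for one SELECT list.
  case class Scope(
      tableColumns: Set[String],
      metadataAttributes: Set[String],
      lateralAliases: Set[String],
      outerColumns: Set[String])

  // Try each level from highest to lowest priority; the first match wins.
  def resolve(name: String, scope: Scope): Option[Resolved] =
    if (scope.tableColumns(name)) Some(TableColumn(name))
    else if (scope.metadataAttributes(name)) Some(MetadataAttribute(name))
    else if (scope.lateralAliases(name)) Some(LateralColumnAlias(name))
    else if (scope.outerColumns(name)) Some(OuterReference(name))
    else None
}
```

With this model, a name that is both a table column and an earlier alias resolves to the table column, while a name that is both an earlier alias and an outer column resolves to the lateral column alias.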

A recent refactor moved the creation of `OuterReference` into the Resolution batch: #38851.
Because lateral column alias has higher resolution priority than outer reference, the rule tries to resolve an `OuterReference` using a lateral column alias, the same way it handles an `UnresolvedAttribute`. If that succeeds, it strips the `OuterReference` and wraps the result with `LateralColumnAliasReference`.
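
A minimal sketch of this re-resolution step, assuming a simplified model: `OuterRef`, `LcaRef`, `reresolve`, and the exception class are hypothetical names, and alias names are compared case-insensitively to mirror the `CaseInsensitiveMap` used in the diff. The ambiguity branch follows the `AMBIGUOUS_LATERAL_COLUMN_ALIAS` error class added by this commit; the message formatting here is only an approximation.

```scala
// Toy model only: these names are hypothetical, not the Catalyst API.
object OuterRefSketch {
  sealed trait Ref
  case class OuterRef(name: String) extends Ref   // already resolved against the outer query
  case class LcaRef(name: String) extends Ref     // re-resolved as a lateral column alias

  final class AmbiguousLateralColumnAlias(name: String, n: Int)
    extends RuntimeException(s"Lateral column alias $name is ambiguous and has $n matches.")

  // aliasesSoFar: alias names defined earlier in the same SELECT list.
  def reresolve(ref: OuterRef, aliasesSoFar: Seq[String]): Ref = {
    val matches = aliasesSoFar.count(_.equalsIgnoreCase(ref.name))
    matches match {
      case 0 => ref                  // no earlier alias: stay an outer reference
      case 1 => LcaRef(ref.name)     // exactly one: strip OuterRef, wrap as LCA reference
      case n => throw new AmbiguousLateralColumnAlias(ref.name, n)
    }
  }
}
```

The design point this models is that an outer reference is only a fallback: if an earlier alias in the same SELECT list has the same name, the alias takes over, and two or more same-named aliases are reported as an error rather than resolved arbitrarily.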

### Why are the changes needed?
Lateral column alias is a popular feature that users have wanted for a long time. It is supported by many other database vendors (Redshift, Snowflake, etc.) and provides a better user experience.

### Does this PR introduce _any_ user-facing change?
Yes. As shown in the example above, Spark will be able to resolve lateral column aliases. I will write the migration guide or release note when most PRs of this feature are merged.

### How was this patch tested?
Existing tests and newly added tests.

Closes #38776 from anchovYu/SPARK-27561-refactor.

Authored-by: Xinyi Yu <xinyi.yu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
anchovYu authored and cloud-fan committed Dec 13, 2022
1 parent a2ceff2 commit 7e9b88b
Showing 13 changed files with 686 additions and 7 deletions.
6 changes: 6 additions & 0 deletions core/src/main/resources/error/error-classes.json
@@ -5,6 +5,12 @@
],
"sqlState" : "42000"
},
"AMBIGUOUS_LATERAL_COLUMN_ALIAS" : {
"message" : [
"Lateral column alias <name> is ambiguous and has <n> matches."
],
"sqlState" : "42000"
},
"AMBIGUOUS_REFERENCE" : {
"message" : [
"Reference <name> is ambiguous, could be: <referenceNames>."
@@ -49,7 +49,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])

override def contains(k: Attribute): Boolean = get(k).isDefined

override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] =
AttributeMap(baseMap.values.toMap + kv)

override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator

@@ -49,6 +49,9 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])

override def contains(k: Attribute): Boolean = get(k).isDefined

override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] =
AttributeMap(baseMap.values.toMap + kv)

override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] =
baseMap.values.toMap + (key -> value)

@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils}
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -288,6 +288,8 @@ class Analyzer(override val catalogManager: CatalogManager)
AddMetadataColumns ::
DeduplicateRelations ::
ResolveReferences ::
WrapLateralColumnAliasReference ::
ResolveLateralColumnAliasReference ::
ResolveExpressionsWithNamePlaceholders ::
ResolveDeserializer ::
ResolveNewInstance ::
@@ -1672,7 +1674,7 @@ class Analyzer(override val catalogManager: CatalogManager)
// Only Project and Aggregate can host star expressions.
case u @ (_: Project | _: Aggregate) =>
Try(s.expand(u.children.head, resolver)) match {
case Success(expanded) => expanded.map(wrapOuterReference)
case Success(expanded) => expanded.map(wrapOuterReference(_))
case Failure(_) => throw e
}
// Do not use the outer plan to resolve the star expression
@@ -1761,6 +1763,117 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}

/**
* The first phase to resolve lateral column alias. See comments in
* [[ResolveLateralColumnAliasReference]] for more detailed explanation.
*/
object WrapLateralColumnAliasReference extends Rule[LogicalPlan] {
import ResolveLateralColumnAliasReference.AliasEntry

private def insertIntoAliasMap(
a: Alias,
idx: Int,
aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = {
val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry])
aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx)))
}

/**
* Use the given lateral alias to resolve the unresolved attribute with the name parts.
*
* Construct a dummy plan with the given lateral alias as project list, use the output of the
* plan to resolve.
* @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve.
*/
private def resolveByLateralAlias(
nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = {
val resolvedAttr = resolveExpressionByPlanOutput(
expr = UnresolvedAttribute(nameParts),
plan = LocalRelation(Seq(lateralAlias.toAttribute)),
throws = false
).asInstanceOf[NamedExpression]
if (resolvedAttr.resolved) {
Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute))
} else {
None
}
}

/**
* Recognize all the attributes in the given expression that reference lateral column aliases
* by looking up the alias map. Resolve these attributes and replace by wrapping with
* [[LateralColumnAliasReference]].
*
* @param currentPlan Because lateral alias has lower resolution priority than table columns,
* the current plan is needed to first try resolving the attribute by its
* children
*/
private def wrapLCARef(
e: NamedExpression,
currentPlan: LogicalPlan,
aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = {
e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) {
case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) &&
resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute] =>
val aliases = aliasMap.get(u.nameParts.head).get
aliases.size match {
case n if n > 1 =>
throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n)
case n if n == 1 && aliases.head.alias.resolved =>
// Only resolved alias can be the lateral column alias
// The lateral alias can be a struct and have nested field, need to construct
// a dummy plan to resolve the expression
resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u)
case _ => u
}
case o: OuterReference
if aliasMap.contains(
o.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR)
.map(_.head)
.getOrElse(o.name)) =>
// Handle OuterReference exactly the same way as UnresolvedAttribute.
val nameParts = o
.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR)
.getOrElse(Seq(o.name))
val aliases = aliasMap.get(nameParts.head).get
aliases.size match {
case n if n > 1 =>
throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n)
case n if n == 1 && aliases.head.alias.resolved =>
resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o)
case _ => o
}
}.asInstanceOf[NamedExpression]
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
plan
} else {
plan.resolveOperatorsUpWithPruning(
_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) {
case p @ Project(projectList, _) if p.childrenResolved
&& !ResolveReferences.containsStar(projectList)
&& projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) =>
var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]())
val newProjectList = projectList.zipWithIndex.map {
case (a: Alias, idx) =>
val lcaWrapped = wrapLCARef(a, p, aliasMap).asInstanceOf[Alias]
// Insert the LCA-resolved alias instead of the unresolved one into map. If it is
// resolved, it can be referenced as LCA by later expressions (chaining).
// Unresolved Alias is also added to the map to perform ambiguous name check, but
// only resolved alias can be LCA.
aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap)
lcaWrapped
case (e, _) =>
wrapLCARef(e, p, aliasMap)
}
p.copy(projectList = newProjectList)
}
}
}
}

private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer]))
}
@@ -2143,7 +2256,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case u @ UnresolvedAttribute(nameParts) => withPosition(u) {
try {
AnalysisContext.get.outerPlan.get.resolveChildren(nameParts, resolver) match {
case Some(resolved) => wrapOuterReference(resolved)
case Some(resolved) => wrapOuterReference(resolved, Some(nameParts))
case None => u
}
} catch {
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, Decorrela
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_WINDOW_EXPRESSION
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_WINDOW_EXPRESSION}
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils}
import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
@@ -638,6 +638,16 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case UnresolvedWindowExpression(_, windowSpec) =>
throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowSpec.name)
})
// This should not happen: a resolved Project or Aggregate should restore or resolve
// all lateral column alias references. This check is added for extra safety.
projectList.foreach(_.transformDownWithPruning(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
case lcaRef: LateralColumnAliasReference if p.resolved =>
throw SparkException.internalError("Resolved Project should not contain " +
s"any LateralColumnAliasReference.\nDebugging information: plan: $p",
context = lcaRef.origin.getQueryContext,
summary = lcaRef.origin.context.summary)
})

case j: Join if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
@@ -714,6 +724,19 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
"operator" -> other.nodeName,
"invalidExprSqls" -> invalidExprSqls.mkString(", ")))

// This should not happen: a resolved Project or Aggregate should restore or resolve
// all lateral column alias references. This check is added for extra safety.
case agg @ Aggregate(_, aggList, _)
if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && agg.resolved =>
aggList.foreach(_.transformDownWithPruning(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
case lcaRef: LateralColumnAliasReference =>
throw SparkException.internalError("Resolved Aggregate should not contain " +
s"any LateralColumnAliasReference.\nDebugging information: plan: $agg",
context = lcaRef.origin.getQueryContext,
summary = lcaRef.origin.context.summary)
})

case _ => // Analysis successful!
}
}
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE
import org.apache.spark.sql.internal.SQLConf

/**
* This rule is the second phase to resolve lateral column alias.
*
* Resolve lateral column alias, which references the alias defined previously in the SELECT list.
* Plan-wise, it handles two types of operators: Project and Aggregate.
* - in Project, push down the referenced lateral alias into a newly created Project and resolve
* the attributes referencing these aliases
* - in Aggregate TODO.
*
* The whole process is generally divided into two phases:
* 1) recognize resolved lateral alias, wrap the attributes referencing them with
* [[LateralColumnAliasReference]]
* 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]].
* For Project, it further resolves the attributes and pushes down the referenced lateral aliases.
* For Aggregate, TODO
*
* Example for Project:
* Before rewrite:
* Project [age AS a, 'a + 1]
* +- Child
*
* After phase 1:
* Project [age AS a, lateralalias(a) + 1]
* +- Child
*
* After phase 2:
* Project [a, a + 1]
* +- Project [child output, age AS a]
* +- Child
*
* Example for Aggregate TODO
*
*
* The name resolution priority:
* local table column > local lateral column alias > outer reference
*
* Because lateral column alias has higher resolution priority than outer reference, it will try
* to resolve an [[OuterReference]] using lateral column alias in phase 1, similar to an
* [[UnresolvedAttribute]]. If it succeeds, it strips [[OuterReference]] and also wraps it with
* [[LateralColumnAliasReference]].
*/
object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
case class AliasEntry(alias: Alias, index: Int)

/**
* A tag to store the nameParts from the original unresolved attribute.
* It is set for [[OuterReference]], used in the current rule to convert [[OuterReference]] back
* to [[LateralColumnAliasReference]].
*/
val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr")

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
plan
} else {
// phase 2: unwrap
plan.resolveOperatorsUpWithPruning(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) {
case p @ Project(projectList, child) if p.resolved
&& projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
var aliasMap = AttributeMap.empty[AliasEntry]
val referencedAliases = collection.mutable.Set.empty[AliasEntry]
def unwrapLCAReference(e: NamedExpression): NamedExpression = {
e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) =>
val aliasEntry = aliasMap.get(lcaRef.a).get
// If there is no chaining of lateral column alias reference, push down the alias
// and unwrap the LateralColumnAliasReference to the NamedExpression inside
// If there is chaining, don't resolve it now; leave it for future rounds.
if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
referencedAliases += aliasEntry
lcaRef.ne
} else {
lcaRef
}
case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) =>
// It shouldn't happen, but restore to unresolved attribute to be safe.
UnresolvedAttribute(lcaRef.nameParts)
}.asInstanceOf[NamedExpression]
}
val newProjectList = projectList.zipWithIndex.map {
case (a: Alias, idx) =>
val lcaResolved = unwrapLCAReference(a)
// Insert the original alias instead of rewritten one to detect chained LCA
aliasMap += (a.toAttribute -> AliasEntry(a, idx))
lcaResolved
case (e, _) =>
unwrapLCAReference(e)
}

if (referencedAliases.isEmpty) {
p
} else {
val outerProjectList = collection.mutable.Seq(newProjectList: _*)
val innerProjectList =
collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*)
referencedAliases.foreach { case AliasEntry(alias: Alias, idx) =>
outerProjectList.update(idx, alias.toAttribute)
innerProjectList += alias
}
p.copy(
projectList = outerProjectList.toSeq,
child = Project(innerProjectList.toSeq, child)
)
}
}
}
}
}
