Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
AddMetadataColumns ::
DeduplicateRelations ::
ResolveReferences ::
// Please do not insert any other rules in between. See the TODO comments in rule
// ResolveLateralColumnAliasReference for more details.
ResolveLateralColumnAliasReference ::
ResolveExpressionsWithNamePlaceholders ::
ResolveDeserializer ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.WindowExpression.hasWindowExpre
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING}
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -131,95 +131,97 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
(pList.exists(hasWindowExpression) && p.expressions.forall(_.resolved) && p.childrenResolved)
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
plan
} else if (plan.containsAnyPattern(TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING)) {
// It should not change the plan if `TempResolvedColumn` or `UnresolvedHaving` is present in
// the query plan. These plans need certain plan shape to get recognized and resolved by other
// rules, such as Filter/Sort + Aggregate to be matched by ResolveAggregateFunctions.
// LCA resolution can break the plan shape, like adding Project above Aggregate.
plan
} else {
// phase 2: unwrap
plan.resolveOperatorsUpWithPruning(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) {
case p @ Project(projectList, child) if ruleApplicableOnOperator(p, projectList)
&& projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
var aliasMap = AttributeMap.empty[AliasEntry]
val referencedAliases = collection.mutable.Set.empty[AliasEntry]
def unwrapLCAReference(e: NamedExpression): NamedExpression = {
e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) =>
val aliasEntry = aliasMap.get(lcaRef.a).get
// If there is no chaining of lateral column alias reference, push down the alias
// and unwrap the LateralColumnAliasReference to the NamedExpression inside
// If there is chaining, don't resolve and save to future rounds
if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
referencedAliases += aliasEntry
lcaRef.ne
} else {
lcaRef
}
case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) =>
// It shouldn't happen, but restore to unresolved attribute to be safe.
UnresolvedAttribute(lcaRef.nameParts)
}.asInstanceOf[NamedExpression]
}
val newProjectList = projectList.zipWithIndex.map {
case (a: Alias, idx) =>
val lcaResolved = unwrapLCAReference(a)
// Insert the original alias instead of rewritten one to detect chained LCA
aliasMap += (a.toAttribute -> AliasEntry(a, idx))
lcaResolved
case (e, _) =>
unwrapLCAReference(e)
}
/** Internal application method. A hand-written bottom-up recursive traverse. */
private def apply0(plan: LogicalPlan): LogicalPlan = {
plan match {
case p: LogicalPlan if !p.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE) =>
p

if (referencedAliases.isEmpty) {
p
} else {
val outerProjectList = collection.mutable.Seq(newProjectList: _*)
val innerProjectList =
collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*)
referencedAliases.foreach { case AliasEntry(alias: Alias, idx) =>
outerProjectList.update(idx, alias.toAttribute)
innerProjectList += alias
}
p.copy(
projectList = outerProjectList.toSeq,
child = Project(innerProjectList.toSeq, child)
)
}
// It should not change the Aggregate (and thus the plan shape) if its parent is an
// UnresolvedHaving, to avoid breaking the shape pattern `UnresolvedHaving - Aggregate`
// matched by ResolveAggregateFunctions. See SPARK-42936 and SPARK-44714 for more details.
case u @ UnresolvedHaving(_, agg: Aggregate) =>
u.copy(child = agg.mapChildren(apply0))

case agg @ Aggregate(groupingExpressions, aggregateExpressions, _)
if ruleApplicableOnOperator(agg, aggregateExpressions)
&& aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
case pOriginal: Project if ruleApplicableOnOperator(pOriginal, pOriginal.projectList)
&& pOriginal.projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
val p @ Project(projectList, child) = pOriginal.mapChildren(apply0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bottom-up resolution. The rest of the code is fully copied and has no changes.

var aliasMap = AttributeMap.empty[AliasEntry]
val referencedAliases = collection.mutable.Set.empty[AliasEntry]
// Rewrites one expression from the project list: each LateralColumnAliasReference
// (LCA) whose target alias is already registered in `aliasMap` is replaced by the
// aliased expression, and the used alias entry is recorded in `referencedAliases`
// so it can later be pushed down into an inner Project.
def unwrapLCAReference(e: NamedExpression): NamedExpression = {
  e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
    case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) =>
      // Direct lookup is safe here: it is guarded by the `contains` check in the
      // case condition above. (Replaces the non-idiomatic `get(...).get`.)
      val aliasEntry = aliasMap(lcaRef.a)
      // If there is no chaining of lateral column alias reference, push down the alias
      // and unwrap the LateralColumnAliasReference to the NamedExpression inside.
      // If there is chaining, don't resolve and save to future rounds.
      if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
        referencedAliases += aliasEntry
        lcaRef.ne
      } else {
        lcaRef
      }
    case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) =>
      // It shouldn't happen, but restore to unresolved attribute to be safe.
      UnresolvedAttribute(lcaRef.nameParts)
  }.asInstanceOf[NamedExpression]
}
val newProjectList = projectList.zipWithIndex.map {
case (a: Alias, idx) =>
val lcaResolved = unwrapLCAReference(a)
// Insert the original alias instead of rewritten one to detect chained LCA
aliasMap += (a.toAttribute -> AliasEntry(a, idx))
lcaResolved
case (e, _) =>
unwrapLCAReference(e)
}

// Check if current Aggregate is eligible to lift up with Project: the aggregate
// expression only contains: 1) aggregate functions, 2) grouping expressions, 3) leaf
// expressions excluding attributes not in grouping expressions
// This check is to prevent unnecessary transformation on invalid plan, to guarantee it
// throws the same exception. For example, cases like non-aggregate expressions not
// in group by, once transformed, will throw a different exception: missing input.
def eligibleToLiftUp(exp: Expression): Boolean = {
exp match {
case _: AggregateExpression => true
case e if groupingExpressions.exists(_.semanticEquals(e)) => true
case a: Attribute => false
case s: ScalarSubquery if s.children.nonEmpty
&& !groupingExpressions.exists(_.semanticEquals(s)) => false
// Manually skip detection on function itself because it can be an aggregate function.
// This is to avoid expressions like sum(salary) over () eligible to lift up.
case WindowExpression(function, spec) =>
function.children.forall(eligibleToLiftUp) && eligibleToLiftUp(spec)
case e => e.children.forall(eligibleToLiftUp)
}
}
if (!aggregateExpressions.forall(eligibleToLiftUp)) {
return agg
if (referencedAliases.isEmpty) {
p
} else {
val outerProjectList = collection.mutable.Seq(newProjectList: _*)
val innerProjectList =
collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*)
referencedAliases.foreach { case AliasEntry(alias: Alias, idx) =>
outerProjectList.update(idx, alias.toAttribute)
innerProjectList += alias
}
p.copy(
projectList = outerProjectList.toSeq,
child = Project(innerProjectList.toSeq, child)
)
}

case aggOriginal: Aggregate
if ruleApplicableOnOperator(aggOriginal, aggOriginal.aggregateExpressions)
&& aggOriginal.aggregateExpressions.exists(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
val agg @ Aggregate(groupingExpressions, aggregateExpressions, _) =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bottom-up resolution. The rest of the code is fully copied and has no changes (except the one I commented out).

aggOriginal.mapChildren(apply0)

// Check if current Aggregate is eligible to lift up with Project: the aggregate
// expression only contains: 1) aggregate functions, 2) grouping expressions, 3) leaf
// expressions excluding attributes not in grouping expressions
// This check is to prevent unnecessary transformation on invalid plan, to guarantee it
// throws the same exception. For example, cases like non-aggregate expressions not
// in group by, once transformed, will throw a different exception: missing input.
// Checks whether `exp` can safely be lifted out of the Aggregate into a Project
// above it. Eligible expressions are built only from: 1) aggregate functions,
// 2) grouping expressions, and 3) leaf expressions excluding attributes that are
// not grouping expressions (see the surrounding comment for why this check exists).
def eligibleToLiftUp(exp: Expression): Boolean = {
  exp match {
    case _: AggregateExpression => true
    case e if groupingExpressions.exists(_.semanticEquals(e)) => true
    // A bare attribute not caught by the grouping-expression case above is invalid.
    // (Wildcard instead of an unused binding `a`.)
    case _: Attribute => false
    // NOTE(review): the second condition looks redundant — a subquery semantically
    // equal to a grouping expression is already matched above — kept for clarity.
    case s: ScalarSubquery if s.children.nonEmpty
      && !groupingExpressions.exists(_.semanticEquals(s)) => false
    // Manually skip detection on function itself because it can be an aggregate function.
    // This is to avoid expressions like sum(salary) over () eligible to lift up.
    case WindowExpression(function, spec) =>
      function.children.forall(eligibleToLiftUp) && eligibleToLiftUp(spec)
    case e => e.children.forall(eligibleToLiftUp)
  }
}
if (!aggregateExpressions.forall(eligibleToLiftUp)) {
agg
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously this line was `return agg`. Using `return` inside a closure can be risky, but for `apply0` the behavior is the same with or without `return`.

} else {
val newAggExprs = collection.mutable.Set.empty[NamedExpression]
val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression]
// Extract the expressions to keep in the Aggregate. Return the transformed expression
Expand Down Expand Up @@ -262,7 +264,33 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
projectList = projectExprs,
child = agg.copy(aggregateExpressions = newAggExprs.toSeq)
)
}
}

case p: LogicalPlan =>
p.mapChildren(apply0)
}
}

/**
 * Entry point of the rule: phase 2 of lateral column alias (LCA) resolution, which
 * unwraps the `LateralColumnAliasReference` markers via the hand-written bottom-up
 * traversal in `apply0`.
 *
 * The plan is returned untouched when either:
 *  - the LCA feature flag is disabled, or
 *  - the plan still contains `TempResolvedColumn`. Such plans need a certain plan
 *    shape to get recognized and resolved by other rules (e.g. Filter/Sort +
 *    Aggregate matched by ResolveAggregateFunctions), and LCA resolution can break
 *    that shape, e.g. by adding a Project above an Aggregate.
 *    TODO: this condition only guarantees to keep the shape after the plan has
 *    `TempResolvedColumn`; it does not cover breaking the shape even before
 *    `TempResolvedColumn` is generated by matching Filter/Sort - Aggregate in
 *    ResolveReferences. Correctness of that case currently relies on the rule
 *    application order: ResolveReferences runs right before
 *    ResolveLateralColumnAliasReference, and the conditions in the two rules
 *    guarantee the case can never happen. We should consider removing this order
 *    dependency while still assuring correctness in the future.
 */
override def apply(plan: LogicalPlan): LogicalPlan = {
  val lcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
  if (!lcaEnabled || plan.containsAnyPattern(TEMP_RESOLVED_COLUMN)) {
    plan
  } else {
    apply0(plan)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -669,13 +669,42 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase {
s"FROM $testTable GROUP BY dept ORDER BY max(name)"),
Row(1, 1) :: Row(2, 2) :: Row(6, 6) :: Nil
)
checkAnswer(
sql("SELECT dept, avg(salary) AS a, a + 10 FROM employee GROUP BY dept ORDER BY max(name)"),
Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil
)
checkAnswer(
sql("SELECT dept, avg(salary) AS a, a + 10 AS b " +
"FROM employee GROUP BY dept ORDER BY max(name)"),
Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil
)
checkAnswer(
sql("SELECT dept, avg(salary) AS a, a + cast(10 as double) AS b " +
"FROM employee GROUP BY dept ORDER BY max(name)"),
Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil
)

// having cond is resolved by aggregate's child
checkAnswer(
sql(s"SELECT avg(bonus) AS dept, dept, avg(salary) AS a, a + 10 AS b " +
s"FROM $testTable GROUP BY dept HAVING max(name) = 'david'"),
Row(1250, 2, 11000, 11010) :: Nil
)
checkAnswer(
sql("SELECT dept, avg(salary) AS a, a + 10 " +
"FROM employee GROUP BY dept HAVING max(bonus) > 1200"),
Row(2, 11000, 11010) :: Nil
)
checkAnswer(
sql("SELECT dept, avg(salary) AS a, a + 10 AS b " +
"FROM employee GROUP BY dept HAVING max(bonus) > 1200"),
Row(2, 11000, 11010) :: Nil
)
checkAnswer(
sql("SELECT dept, avg(salary) AS a, a + cast(10 as double) AS b " +
"FROM employee GROUP BY dept HAVING max(bonus) > 1200"),
Row(2, 11000, 11010) :: Nil
)
// having cond is resolved by aggregate itself
checkAnswer(
sql(s"SELECT avg(bonus) AS a, a FROM $testTable GROUP BY dept HAVING a > 1200"),
Expand Down Expand Up @@ -1139,4 +1168,120 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase {
// non group by or non aggregate function in Aggregate queries negative cases are covered in
// "Aggregate expressions not eligible to lift up, throws same error as inline".
}

test("Still resolves when Aggregate with LCA is not the direct child of Having") {
// Previously there was a limitation of lca that it can't resolve the query when it satisfies
// all the following criteria:
// 1) the main (outer) query has having clause
// 2) there is a window expression in the query
// 3) in the same SELECT list as the window expression in 2), there is an lca
// Though [UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_WITH_WINDOW_AND_HAVING] is
// still not supported, after SPARK-44714, a lot other limitations are
// lifted because it allows to resolve LCA when the query has UnresolvedHaving but its direct
// child does not contain an LCA.
// Testcases in this test focus on this change regarding enablement of resolution.
// Each query below places the window + LCA in a CTE or subquery, so the Aggregate
// directly under HAVING carries no LCA itself and resolution should now succeed.

// CTE definition contains window and LCA; outer query contains having
// (`r, r` in the CTE select list is the lateral reference to the window alias).
checkAnswer(
sql(
s"""
|with w as (
| select name, dept, salary, rank() over (partition by dept order by salary) as r, r
| from $testTable
|)
|select dept
|from w
|group by dept
|having max(salary) > 10000
|""".stripMargin),
Row(2) :: Row(6) :: Nil
)
// Same shape, plus an LCA (`d`) in the outer select list.
checkAnswer(
sql(
s"""
|with w as (
| select name, dept, salary, rank() over (partition by dept order by salary) as r, r
| from $testTable
|)
|select dept as d, d
|from w
|group by dept
|having max(salary) > 10000
|""".stripMargin),
Row(2, 2) :: Row(6, 6) :: Nil
)
// HAVING references the outer select-list alias `d`.
checkAnswer(
sql(
s"""
|with w as (
| select name, dept, salary, rank() over (partition by dept order by salary) as r, r
| from $testTable
|)
|select dept as d
|from w
|group by dept
|having d = 2
|""".stripMargin),
Row(2) :: Nil
)

// inner subquery contains window and LCA; outer query contains having
// (`a + 1 as e` is the lateral reference inside the subquery select list).
checkAnswer(
sql(
s"""
|SELECT
| dept
|FROM
| (
| select
| name, dept, salary, rank() over (partition by dept order by salary) as r,
| 1 as a, a + 1 as e
| FROM
| $testTable
| ) AS inner_t
|GROUP BY
| dept
|HAVING max(salary) > 10000
|""".stripMargin),
Row(2) :: Row(6) :: Nil
)
// Same shape, plus an LCA (`d`) in the outer select list.
checkAnswer(
sql(
s"""
|SELECT
| dept as d, d
|FROM
| (
| select
| name, dept, salary, rank() over (partition by dept order by salary) as r,
| 1 as a, a + 1 as e
| FROM
| $testTable
| ) AS inner_t
|GROUP BY
| dept
|HAVING max(salary) > 10000
|""".stripMargin),
Row(2, 2) :: Row(6, 6) :: Nil
)
// HAVING references the outer select-list alias `d`.
checkAnswer(
sql(
s"""
|SELECT
| dept as d
|FROM
| (
| select
| name, dept, salary, rank() over (partition by dept order by salary) as r,
| 1 as a, a + 1 as e
| FROM
| $testTable
| ) AS inner_t
|GROUP BY
| dept
|HAVING d = 2
|""".stripMargin),
Row(2) :: Nil
)
}
}