Skip to content

Commit

Permalink
[SPARK-40773][SQL] Refactor checkCorrelationsInSubquery
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR refactors `checkCorrelationsInSubquery` in CheckAnalysis to use recursion instead of `foreachUp`.

### Why are the changes needed?

Currently, the logic in `checkCorrelationsInSubquery` is inefficient and difficult to understand. It uses `foreachUp` to traverse the subquery plan tree and, for some plan nodes, additionally traverses the node's entire subtree to check whether it contains any outer references. Using recursion instead lets us traverse the plan tree only once, which improves both performance and readability.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing unit tests.

Closes apache#38226 from allisonwang-db/spark-40773-check-subquery.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
allisonwang-db authored and SandishKumarHN committed Dec 12, 2022
1 parent 877043c commit 00aa52f
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -979,9 +979,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
}
}

// Make sure a plan's subtree does not contain outer references
def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = {
if (hasOuterReferences(p)) {
// Make sure expressions of a plan do not contain outer references.
def failOnOuterReferenceInPlan(p: LogicalPlan): Unit = {
if (p.expressions.exists(containsOuter)) {
p.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED",
Expand Down Expand Up @@ -1078,18 +1078,24 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
}
}

val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression]
// Recursively check invalid outer references in the plan.
def checkPlan(
plan: LogicalPlan,
aggregated: Boolean = false,
canContainOuter: Boolean = true): Unit = {

if (!canContainOuter) {
failOnOuterReferenceInPlan(plan)
}

// Simplify the predicates before validating any unsupported correlation patterns in the plan.
AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
// Approve operators allowed in a correlated subquery
// There are 4 categories:
// 1. Operators that are allowed anywhere in a correlated subquery, and,
// by definition of the operators, they either do not contain
// any columns or cannot host outer references.
// 2. Operators that are allowed anywhere in a correlated subquery
// so long as they do not host outer references.
// 3. Operators that need special handlings. These operators are
// 3. Operators that need special handling. These operators are
// Filter, Join, Aggregate, and Generate.
//
// Any operators that are not in the above list are allowed
Expand All @@ -1099,99 +1105,114 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
// A correlation path is defined as the sub-tree of all the operators that
// are on the path from the operator hosting the correlated expressions
// up to the operator producing the correlated values.
plan match {
// Category 1:
// ResolvedHint, LeafNode, Repartition, and SubqueryAlias
case p @ (_: ResolvedHint | _: LeafNode | _: Repartition | _: SubqueryAlias) =>
p.children.foreach(child => checkPlan(child, aggregated, canContainOuter))

// Category 2:
// These operators can be anywhere in a correlated subquery,
// so long as they do not host outer references in the operators.
case p: Project =>
failOnInvalidOuterReference(p)
checkPlan(p.child, aggregated, canContainOuter)

case s: Sort =>
failOnInvalidOuterReference(s)
checkPlan(s.child, aggregated, canContainOuter)

case r: RepartitionByExpression =>
failOnInvalidOuterReference(r)
checkPlan(r.child, aggregated, canContainOuter)

case l: LateralJoin =>
failOnInvalidOuterReference(l)
checkPlan(l.child, aggregated, canContainOuter)

// Category 3:
// Filter is one of the two operators allowed to host correlated expressions.
// The other operator is Join. Filter can be anywhere in a correlated subquery.
case f: Filter =>
failOnInvalidOuterReference(f)
val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)
val unsupportedPredicates = correlated.filterNot(DecorrelateInnerQuery.canPullUpOverAgg)
if (aggregated) {
failOnUnsupportedCorrelatedPredicate(unsupportedPredicates, f)
}
checkPlan(f.child, aggregated, canContainOuter)

// Aggregate cannot host any correlated expressions.
// It can be on a correlation path if the correlation contains
// only supported correlated equality predicates.
// It cannot be on a correlation path if the correlation has
// non-equality correlated predicates.
case a: Aggregate =>
failOnInvalidOuterReference(a)
checkPlan(a.child, aggregated = true, canContainOuter)

// Distinct does not host any correlated expressions, but during the optimization phase
// it will be rewritten as Aggregate, which can only be on a correlation path if the
// correlation contains only the supported correlated equality predicates.
// Only block it for lateral subqueries because scalar subqueries must be aggregated
// and it does not impact the results for IN/EXISTS subqueries.
case d: Distinct =>
checkPlan(d.child, aggregated = isLateral, canContainOuter)

// Join can host correlated expressions.
case j @ Join(left, right, joinType, _, _) =>
failOnInvalidOuterReference(j)
joinType match {
// Inner join, like Filter, can be anywhere.
case _: InnerLike =>
j.children.foreach(child => checkPlan(child, aggregated, canContainOuter))

// Left outer join's right operand cannot be on a correlation path.
// LeftAnti and ExistenceJoin are special cases of LeftOuter.
// Note that ExistenceJoin cannot be expressed externally in either SQL or DataFrame
// so it should not show up here in the Analysis phase. This is just a safety net.
//
// LeftSemi does not allow output from the right operand.
// Any correlated references in the subplan
// of the right operand cannot be pulled up.
case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
checkPlan(left, aggregated, canContainOuter)
checkPlan(right, aggregated, canContainOuter = false)

// Likewise, Right outer join's left operand cannot be on a correlation path.
case RightOuter =>
checkPlan(left, aggregated, canContainOuter = false)
checkPlan(right, aggregated, canContainOuter)

// Any other join types not explicitly listed above,
// including Full outer join, are treated as Category 4.
case _ =>
j.children.foreach(child => checkPlan(child, aggregated, canContainOuter = false))
}

// Category 1:
// ResolvedHint, LeafNode, Repartition, and SubqueryAlias
case _: ResolvedHint | _: LeafNode | _: Repartition | _: SubqueryAlias =>

// Category 2:
// These operators can be anywhere in a correlated subquery.
// so long as they do not host outer references in the operators.
case p: Project =>
failOnInvalidOuterReference(p)

case s: Sort =>
failOnInvalidOuterReference(s)

case r: RepartitionByExpression =>
failOnInvalidOuterReference(r)

case l: LateralJoin =>
failOnInvalidOuterReference(l)

// Category 3:
// Filter is one of the two operators allowed to host correlated expressions.
// The other operator is Join. Filter can be anywhere in a correlated subquery.
case f: Filter =>
val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)
unsupportedPredicates ++= correlated.filterNot(DecorrelateInnerQuery.canPullUpOverAgg)
failOnInvalidOuterReference(f)

// Aggregate cannot host any correlated expressions
// It can be on a correlation path if the correlation contains
// only supported correlated equality predicates.
// It cannot be on a correlation path if the correlation has
// non-equality correlated predicates.
case a: Aggregate =>
failOnInvalidOuterReference(a)
failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a)

// Distinct does not host any correlated expressions, but during the optimization phase
// it will be rewritten as Aggregate, which can only be on a correlation path if the
// correlation contains only the supported correlated equality predicates.
// Only block it for lateral subqueries because scalar subqueries must be aggregated
// and it does not impact the results for IN/EXISTS subqueries.
case d: Distinct =>
if (isLateral) {
failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, d)
}

// Join can host correlated expressions.
case j @ Join(left, right, joinType, _, _) =>
joinType match {
// Inner join, like Filter, can be anywhere.
case _: InnerLike =>
failOnInvalidOuterReference(j)

// Left outer join's right operand cannot be on a correlation path.
// LeftAnti and ExistenceJoin are special cases of LeftOuter.
// Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame
// so it should not show up here in Analysis phase. This is just a safety net.
//
// LeftSemi does not allow output from the right operand.
// Any correlated references in the subplan
// of the right operand cannot be pulled up.
case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
failOnInvalidOuterReference(j)
failOnOuterReferenceInSubTree(right)

// Likewise, Right outer join's left operand cannot be on a correlation path.
case RightOuter =>
failOnInvalidOuterReference(j)
failOnOuterReferenceInSubTree(left)

// Any other join types not explicitly listed above,
// including Full outer join, are treated as Category 4.
case _ =>
failOnOuterReferenceInSubTree(j)
}
// Generator with join=true, i.e., expressed with
// LATERAL VIEW [OUTER], similar to inner join,
// allows correlation under it
// but must not host any outer references.
// Note:
// Generator with requiredChildOutput.isEmpty is treated as Category 4.
case g: Generate if g.requiredChildOutput.nonEmpty =>
failOnInvalidOuterReference(g)
checkPlan(g.child, aggregated, canContainOuter)

// Category 4: Any other operators not in the above 3 categories
// cannot be on a correlation path, that is they are allowed only
// under a correlation point but they and their descendant operators
// are not allowed to have any correlated expressions.
case p =>
p.children.foreach(p => checkPlan(p, aggregated, canContainOuter = false))
}
}

// Generator with join=true, i.e., expressed with
// LATERAL VIEW [OUTER], similar to inner join,
// allows to have correlation under it
// but must not host any outer references.
// Note:
// Generator with requiredChildOutput.isEmpty is treated as Category 4.
case g: Generate if g.requiredChildOutput.nonEmpty =>
failOnInvalidOuterReference(g)

// Category 4: Any other operators not in the above 3 categories
// cannot be on a correlation path, that is they are allowed only
// under a correlation point but they and their descendant operators
// are not allowed to have any correlated expressions.
case p =>
failOnOuterReferenceInSubTree(p)
}}
// Simplify the predicates before validating any unsupported correlation patterns in the plan.
AnalysisHelper.allowInvokingTransformsInAnalyzer {
checkPlan(BooleanSimplification(sub))
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ org.apache.spark.sql.AnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
"messageParameters" : {
"treeNode" : "(cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\nAggregate [cast(udf(cast(max(cast(udf(cast(v#x as string)) as int)) as string)) as int) AS udf(max(udf(v)))#x]\n+- Filter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\n +- SubqueryAlias t2\n +- View (`t2`, [k#x,v#x])\n +- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]\n +- Project [k#x, v#x]\n +- SubqueryAlias t2\n +- LocalRelation [k#x, v#x]\n"
"treeNode" : "(cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\nFilter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\n+- SubqueryAlias t2\n +- View (`t2`, [k#x,v#x])\n +- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]\n +- Project [k#x, v#x]\n +- SubqueryAlias t2\n +- LocalRelation [k#x, v#x]\n"
},
"queryContext" : [ {
"objectType" : "",
Expand Down
20 changes: 9 additions & 11 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -891,9 +891,9 @@ class SubquerySuite extends QueryTest
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment = "(select c1 from t2 where t1.c1 = 2) t2",
start = 110,
stop = 147))
fragment = "select c1 from t2 where t1.c1 = 2",
start = 111,
stop = 143))

// Right outer join (ROJ) in EXISTS subquery context
val exception2 = intercept[AnalysisException] {
Expand All @@ -913,9 +913,9 @@ class SubquerySuite extends QueryTest
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment = "(select c1 from t2 where t1.c1 = 2) t2",
start = 74,
stop = 111))
fragment = "select c1 from t2 where t1.c1 = 2",
start = 75,
stop = 107))

// SPARK-18578: Full outer join (FOJ) in scalar subquery context
val exception3 = intercept[AnalysisException] {
Expand All @@ -934,11 +934,9 @@ class SubquerySuite extends QueryTest
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment =
"""full join t3
| on t2.c1=t3.c1""".stripMargin,
start = 112,
stop = 154))
fragment = "select c1 from t2 where t1.c1 = 2 and t1.c1=t2.c1",
start = 41,
stop = 90))
}
}

Expand Down

0 comments on commit 00aa52f

Please sign in to comment.