Skip to content

Commit

Permalink
[SPARK-35545][SQL] Split SubqueryExpression's children field into out…
Browse files Browse the repository at this point in the history
…er attributes and join conditions

### What changes were proposed in this pull request?
This PR refactors `SubqueryExpression` class. It removes the children field from SubqueryExpression's constructor and adds `outerAttrs` and `joinCond`.

### Why are the changes needed?
Currently, the children field of a subquery expression is used to store both collected outer references in the subquery plan and join conditions after correlated predicates are pulled up.

For example:
`SELECT (SELECT max(c1) FROM t1 WHERE t1.c1 = t2.c1) FROM t2`

During the analysis phase, outer references in the subquery are stored in the children field: `scalar-subquery [t2.c1]`, but after the optimizer rule `PullupCorrelatedPredicates`, the children field will be used to store the join conditions, which contain both the inner and the outer references: `scalar-subquery [t1.c1 = t2.c1]`. This is why the references of SubqueryExpression excludes the inner plan's output:
https://github.com/apache/spark/blob/29ed1a2de42e7a663f764192fce157a9f23029b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala#L68-L69

This can be confusing and error-prone. The references for a subquery expression should always be defined as outer attribute references.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #32687 from allisonwang-db/refactor-subquery-expr.

Authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
allisonwang-db authored and cloud-fan committed May 31, 2021
1 parent 1a55019 commit 806da9d
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 51 deletions.
Expand Up @@ -2343,11 +2343,11 @@ class Analyzer(override val catalogManager: CatalogManager)
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY,
EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
case s @ ScalarSubquery(sub, _, exprId, _) if !sub.resolved =>
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
case e @ Exists(sub, _, exprId, _) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
case InSubquery(values, l @ ListQuery(_, _, exprId, _))
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
Expand Down
Expand Up @@ -744,17 +744,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
checkAnalysis(expr.plan)

expr match {
case ScalarSubquery(query, conditions, _) =>
case ScalarSubquery(query, outerAttrs, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
failAnalysis(
s"Scalar subquery must return only one column, but got ${query.output.size}")
}

if (conditions.nonEmpty) {
if (outerAttrs.nonEmpty) {
cleanQueryInScalarSubquery(query) match {
case a: Aggregate => checkAggregateInScalarSubquery(conditions, query, a)
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(conditions, query, a)
case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a)
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a)
case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok
case fail => failAnalysis(s"Correlated scalar subqueries must be aggregated: $fail")
}
Expand Down
Expand Up @@ -322,7 +322,7 @@ abstract class TypeCoercionBase {

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _))
case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions))
if !i.resolved && lhs.length == sub.output.length =>
// LHS is the value expressions of IN subquery.
// RHS is the subquery output.
Expand All @@ -345,7 +345,7 @@ abstract class TypeCoercionBase {
}

val newSub = Project(castedRhs, sub)
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output))
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output, conditions))
} else {
i
}
Expand Down
Expand Up @@ -59,14 +59,22 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {

/**
* A base interface for expressions that contain a [[LogicalPlan]].
*
* @param plan: the subquery plan
* @param outerAttrs: the outer references in the subquery plan
* @param exprId: ID of the expression
* @param joinCond: the join conditions with the outer query. It contains both inner and outer
* query references.
*/
abstract class SubqueryExpression(
plan: LogicalPlan,
children: Seq[Expression],
exprId: ExprId) extends PlanExpression[LogicalPlan] {
outerAttrs: Seq[Expression],
exprId: ExprId,
joinCond: Seq[Expression] = Nil) extends PlanExpression[LogicalPlan] {
override lazy val resolved: Boolean = childrenResolved && plan.resolved
override lazy val references: AttributeSet =
if (plan.resolved) super.references -- plan.outputSet else super.references
AttributeSet.fromAttributeSets(outerAttrs.map(_.references))
override def children: Seq[Expression] = outerAttrs ++ joinCond
override def withNewPlan(plan: LogicalPlan): SubqueryExpression
override def semanticEquals(o: Expression): Boolean = o match {
case p: SubqueryExpression =>
Expand Down Expand Up @@ -240,9 +248,10 @@ object SubExprUtils extends PredicateHelper {
*/
case class ScalarSubquery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
joinCond: Seq[Expression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
override def dataType: DataType = {
assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
plan.schema.fields.head.dataType
Expand All @@ -253,12 +262,16 @@ case class ScalarSubquery(
override lazy val canonicalized: Expression = {
ScalarSubquery(
plan.canonicalized,
children.map(_.canonicalized),
ExprId(0))
outerAttrs.map(_.canonicalized),
ExprId(0),
joinCond.map(_.canonicalized))
}

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
newChildren: IndexedSeq[Expression]): ScalarSubquery =
copy(
outerAttrs = newChildren.take(outerAttrs.size),
joinCond = newChildren.drop(outerAttrs.size))

final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY)
}
Expand Down Expand Up @@ -286,10 +299,11 @@ object ScalarSubquery {
*/
case class ListQuery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
childOutputs: Seq[Attribute] = Seq.empty)
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
childOutputs: Seq[Attribute] = Seq.empty,
joinCond: Seq[Expression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
override def dataType: DataType = if (childOutputs.length > 1) {
childOutputs.toStructType
} else {
Expand All @@ -302,13 +316,16 @@ case class ListQuery(
override lazy val canonicalized: Expression = {
ListQuery(
plan.canonicalized,
children.map(_.canonicalized),
outerAttrs.map(_.canonicalized),
ExprId(0),
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]),
joinCond.map(_.canonicalized))
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
copy(children = newChildren)
copy(
outerAttrs = newChildren.take(outerAttrs.size),
joinCond = newChildren.drop(outerAttrs.size))

final override def nodePatternsInternal: Seq[TreePattern] = Seq(LIST_SUBQUERY)
}
Expand Down Expand Up @@ -341,21 +358,25 @@ case class ListQuery(
*/
case class Exists(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable {
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
joinCond: Seq[Expression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Predicate with Unevaluable {
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
override def toString: String = s"exists#${exprId.id} $conditionString"
override lazy val canonicalized: Expression = {
Exists(
plan.canonicalized,
children.map(_.canonicalized),
ExprId(0))
outerAttrs.map(_.canonicalized),
ExprId(0),
joinCond.map(_.canonicalized))
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
copy(children = newChildren)
copy(
outerAttrs = newChildren.take(outerAttrs.size),
joinCond = newChildren.drop(outerAttrs.size))

final override def nodePatternsInternal: Seq[TreePattern] = Seq(EXISTS_SUBQUERY)
}
Expand Up @@ -110,19 +110,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {

// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
case (p, Exists(sub, _, _, conditions)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
buildJoin(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(Exists(sub, conditions, _))) =>
case (p, Not(Exists(sub, _, _, conditions))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
buildJoin(outerPlan, sub, LeftAnti, joinCond)
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
case (p, InSubquery(values, ListQuery(sub, _, _, _, conditions))) =>
// Deduplicate conflicting attributes if any.
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint.NONE)
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, conditions)))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
Expand Down Expand Up @@ -166,12 +166,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
var newPlan = plan
val newExprs = exprs.map { e =>
e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
case Exists(sub, conditions, _) =>
case Exists(sub, _, _, conditions) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
newPlan =
buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
case Not(InSubquery(values, ListQuery(sub, conditions, _, _))) =>
case Not(InSubquery(values, ListQuery(sub, _, _, _, conditions))) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
// Deduplicate conflicting attributes if any.
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
Expand All @@ -192,7 +192,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
newPlan = Join(newPlan, newSub, ExistenceJoin(exists), Some(finalJoinCond), JoinHint.NONE)
Not(exists)
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
case InSubquery(values, ListQuery(sub, _, _, _, conditions)) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
// Deduplicate conflicting attributes if any.
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
Expand Down Expand Up @@ -306,15 +306,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper

plan.transformExpressionsWithPruning(_.containsAnyPattern(
SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
case ScalarSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, outerPlans)
ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
case Exists(sub, children, exprId) if children.nonEmpty =>
ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
case Exists(sub, children, exprId, conditions) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
Exists(newPlan, getJoinCondition(newCond, children), exprId)
case ListQuery(sub, children, exprId, childOutputs) if children.nonEmpty =>
Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions))
case ListQuery(sub, children, exprId, childOutputs, conditions) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
ListQuery(newPlan, getJoinCondition(newCond, children), exprId, childOutputs)
ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
}
}

Expand Down Expand Up @@ -524,7 +524,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(sub, conditions, _)) =>
case (currentChild, ScalarSubquery(sub, _, _, conditions)) =>
val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
val origOutput = query.output.head

Expand Down
Expand Up @@ -118,14 +118,14 @@ case class InsertAdaptiveSparkPlan(
return subqueryMap.toMap
}
plan.foreach(_.expressions.foreach(_.foreach {
case expressions.ScalarSubquery(p, _, exprId)
case expressions.ScalarSubquery(p, _, exprId, _)
if !subqueryMap.contains(exprId.id) =>
val executedPlan = compileSubquery(p)
verifyAdaptivePlan(executedPlan, p)
val subquery = SubqueryExec.createForScalarSubquery(
s"subquery#${exprId.id}", executedPlan)
subqueryMap.put(exprId.id, subquery)
case expressions.InSubquery(_, ListQuery(query, _, exprId, _))
case expressions.InSubquery(_, ListQuery(query, _, exprId, _, _))
if !subqueryMap.contains(exprId.id) =>
val executedPlan = compileSubquery(query)
verifyAdaptivePlan(executedPlan, query)
Expand Down
Expand Up @@ -31,9 +31,9 @@ case class PlanAdaptiveSubqueries(
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressionsWithPruning(
_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
case expressions.ScalarSubquery(_, _, exprId) =>
case expressions.ScalarSubquery(_, _, exprId, _) =>
execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) =>
val expr = if (values.length == 1) {
values.head
} else {
Expand Down
Expand Up @@ -185,7 +185,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
SubqueryExec.createForScalarSubquery(
s"scalar-subquery#${subquery.exprId.id}", executedPlan),
subquery.exprId)
case expressions.InSubquery(values, ListQuery(query, _, exprId, _)) =>
case expressions.InSubquery(values, ListQuery(query, _, exprId, _, _)) =>
val expr = if (values.length == 1) {
values.head
} else {
Expand Down
Expand Up @@ -970,7 +970,7 @@ class PlanResolutionSuite extends AnalysisTest {
query match {
case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
_, _, _) =>
_, _, _, _) =>
assert(projects.size == 1 && projects.head.name == "s.name")
assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
case o => fail("Unexpected subquery: \n" + o.treeString)
Expand Down Expand Up @@ -1046,7 +1046,7 @@ class PlanResolutionSuite extends AnalysisTest {
query match {
case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
_, _, _) =>
_, _, _, _) =>
assert(projects.size == 1 && projects.head.name == "s.name")
assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
case o => fail("Unexpected subquery: \n" + o.treeString)
Expand Down

0 comments on commit 806da9d

Please sign in to comment.