[SPARK-36063][SQL] Optimize OneRowRelation subqueries
### What changes were proposed in this pull request?
This PR adds an optimization for scalar and lateral subqueries with `OneRowRelation` as leaf nodes. It inlines such subqueries before decorrelation to avoid rewriting them as left outer joins. It also introduces a flag to turn this optimization on or off: `spark.sql.optimizer.optimizeOneRowRelationSubquery` (default: `true`).

For example:
```sql
select (select c1) from t
```
Analyzed plan:
```
Project [scalar-subquery#17 [c1#18] AS scalarsubquery(c1)#22]
:  +- Project [outer(c1#18)]
:     +- OneRowRelation
+- LocalRelation [c1#18, c2#19]
```

Optimized plan before this PR:
```
Project [c1#18#25 AS scalarsubquery(c1)#22]
+- Join LeftOuter, (c1#24 <=> c1#18)
   :- LocalRelation [c1#18]
   +- Aggregate [c1#18], [c1#18 AS c1#18#25, c1#18 AS c1#24]
      +- LocalRelation [c1#18]
```

Optimized plan after this PR:
```
LocalRelation [scalarsubquery(c1)#22]
```
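As a rough illustration (not part of this PR), the effect can be observed from `spark-shell` by toggling the new flag. The temp view `t` and its columns below are made up to mirror the example above:

```scala
// Hypothetical spark-shell session; `t` mirrors the LocalRelation [c1, c2] above.
spark.range(3).selectExpr("id AS c1", "id + 10 AS c2").createOrReplaceTempView("t")

// With the flag on (default), the scalar subquery is inlined: no LeftOuter join in the plan.
spark.sql("SELECT (SELECT c1) FROM t").explain(true)

// With the flag off, the subquery is decorrelated into a left outer join as before.
spark.conf.set("spark.sql.optimizer.optimizeOneRowRelationSubquery", "false")
spark.sql("SELECT (SELECT c1) FROM t").explain(true)
```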

### Why are the changes needed?
To produce simpler query plans: `OneRowRelation` subqueries are inlined instead of being rewritten into left outer joins.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added new unit tests.

Closes #33284 from allisonwang-db/spark-36063-optimize-subquery-one-row-relation.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
allisonwang-db authored and cloud-fan committed Jul 22, 2021
1 parent dcb7db5 commit de8e4be
Showing 10 changed files with 283 additions and 16 deletions.
@@ -390,6 +390,13 @@ package object dsl {
condition: Option[Expression] = None): LogicalPlan =
Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE)

  def lateralJoin(
      otherPlan: LogicalPlan,
      joinType: JoinType = Inner,
      condition: Option[Expression] = None): LogicalPlan = {
    LateralJoin(logicalPlan, LateralSubquery(otherPlan), joinType, condition)
  }
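A minimal sketch of how the new `lateralJoin` test-DSL helper above might be used; the relation and expressions are illustrative, not taken from the PR's tests, and the usual catalyst DSL imports are assumed:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}

// Roughly `r` joined laterally with `(SELECT 1 AS one)`; in a correlated test the
// subquery's expressions would be wrapped in OuterReference instead of a literal.
val r = LocalRelation('a.int, 'b.int)
val plan = r.lateralJoin(OneRowRelation().select(Literal(1).as("one")))
```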

def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder](
otherPlan: LogicalPlan,
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
@@ -126,13 +126,15 @@ object SubExprUtils extends PredicateHelper {
/**
* Returns an expression after removing the OuterReference shell.
*/
def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r }
def stripOuterReference[E <: Expression](e: E): E = {
e.transform { case OuterReference(r) => r }.asInstanceOf[E]
}

/**
* Returns the list of expressions after removing the OuterReference shell from each of
* the expressions.
*/
def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference)
def stripOuterReferences[E <: Expression](e: Seq[E]): Seq[E] = e.map(stripOuterReference)

/**
* Returns the logical plan after removing the OuterReference shell from all the expressions
@@ -156,6 +156,23 @@ object DecorrelateInnerQuery extends PredicateHelper {
expressions.map(replaceOuterReference(_, outerReferenceMap))
}

  /**
   * Replace all outer references in the given named expressions and keep the output
   * attributes unchanged.
   */
  private def replaceOuterInNamedExpressions(
      expressions: Seq[NamedExpression],
      outerReferenceMap: AttributeMap[Attribute]): Seq[NamedExpression] = {
    expressions.map { expr =>
      val newExpr = replaceOuterReference(expr, outerReferenceMap)
      if (!newExpr.toAttribute.semanticEquals(expr.toAttribute)) {
        Alias(newExpr, expr.name)(expr.exprId)
      } else {
        newExpr
      }
    }
  }
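For intuition, a hedged sketch (identifiers invented here, not from the PR) of the contract this helper preserves: the rewritten expression may change, but its output attribute — name and `exprId` — must not, so parent operators keep resolving against the same attribute after decorrelation.

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, OuterReference}
import org.apache.spark.sql.types.IntegerType

// `outer(x) AS x1` has its outer reference substituted by the domain attribute below;
// re-aliasing with the original exprId keeps the output attribute of the Project unchanged.
val x = AttributeReference("x", IntegerType)()        // outer attribute
val domainX = AttributeReference("x", IntegerType)()  // domain attribute introduced by decorrelation
val original = Alias(OuterReference(x), "x1")()       // output attribute: x1 with some exprId
val rewritten = Alias(domainX, "x1")(exprId = original.exprId)
assert(rewritten.toAttribute semanticEquals original.toAttribute)
```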

/**
* Return all references that are presented in the join conditions but not in the output
* of the given named expressions.
@@ -429,8 +446,9 @@
val newOuterReferences = parentOuterReferences ++ outerReferences
val (newChild, joinCond, outerReferenceMap) =
decorrelate(child, newOuterReferences, aggregated)
// Replace all outer references in the original project list.
val newProjectList = replaceOuterReferences(projectList, outerReferenceMap)
// Replace all outer references in the original project list and keep the output
// attributes unchanged.
val newProjectList = replaceOuterInNamedExpressions(projectList, outerReferenceMap)
// Preserve required domain attributes in the join condition by adding the missing
// references to the new project list.
val referencesToAdd = missingReferences(newProjectList, joinCond)
@@ -442,9 +460,10 @@
val newOuterReferences = parentOuterReferences ++ outerReferences
val (newChild, joinCond, outerReferenceMap) =
decorrelate(child, newOuterReferences, aggregated = true)
// Replace all outer references in grouping and aggregate expressions.
// Replace all outer references in grouping and aggregate expressions, and keep
// the output attributes unchanged.
val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap)
val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap)
val newAggExpr = replaceOuterInNamedExpressions(aggregateExpressions, outerReferenceMap)
// Add all required domain attributes to both grouping and aggregate expressions.
val referencesToAdd = missingReferences(newAggExpr, joinCond)
val newAggregate = a.copy(
@@ -179,6 +179,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
      // non-nullable when an empty relation child of a Union is removed
      UpdateAttributeNullability) ::
    Batch("Pullup Correlated Expressions", Once,
      OptimizeOneRowRelationSubquery,
      PullupCorrelatedPredicates) ::
    // Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense
    // to enforce idempotence on it and we change this batch from Once to FixedPoint(1).
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@@ -711,3 +712,47 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] {
Join(left, newRight, joinType, newCond, JoinHint.NONE)
}
}

/**
 * This rule optimizes subqueries with OneRowRelation as leaf nodes.
 */
object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {

  object OneRowSubquery {
    def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = {
      CollapseProject(EliminateSubqueryAliases(plan)) match {
        case Project(projectList, _: OneRowRelation) => Some(stripOuterReferences(projectList))
        case _ => None
      }
    }
  }

  private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = {
    plan.find(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).isDefined
  }

  /**
   * Rewrite a subquery expression into one or more expressions. The rewrite can only be done
   * if there are no nested subqueries in the subquery plan.
   */
  private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
    case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None)
        if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty =>
      Project(left.output ++ projectList, left)
    case p: LogicalPlan => p.transformExpressionsUpWithPruning(
      _.containsPattern(SCALAR_SUBQUERY)) {
      case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _)
          if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
        assert(projectList.size == 1)
        projectList.head
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    if (!conf.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)) {
      plan
    } else {
      rewrite(plan)
    }
  }
}
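To make the rewrite concrete, here is a hedged sketch of the scalar-subquery branch in catalyst test-DSL terms; the relation, expressions, and expected result are illustrative, not taken from the PR's tests:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Literal, ScalarSubquery}
import org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowRelationSubquery
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}

// Roughly `SELECT (SELECT 1 AS one) AS c FROM r`.
val r = LocalRelation('a.int)
val query = r.select(ScalarSubquery(OneRowRelation().select(Literal(1).as("one"))).as("c"))

// The rule inlines the uncorrelated OneRowRelation subquery's single project-list
// expression, leaving roughly Project [1 AS one AS c] over r — no subquery, no join.
val optimized = OptimizeOneRowRelationSubquery(query)
```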
@@ -435,6 +435,23 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
subqueries ++ subqueries.flatMap(_.subqueriesAll)
}

  /**
   * Returns a copy of this node where the given partial function has been recursively applied
   * first to the subqueries in this node's children, then this node's children, and finally
   * this node itself (post-order). When the partial function does not apply to a given node,
   * it is left unchanged.
   */
  def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
    transformUp { case plan =>
      val transformed = plan transformExpressionsUp {
        case planExpression: PlanExpression[PlanType] =>
          val newPlan = planExpression.plan.transformUpWithSubqueries(f)
          planExpression.withNewPlan(newPlan)
      }
      f.applyOrElse[PlanType, PlanType](transformed, identity)
    }
  }
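A small usage sketch of the new helper (hypothetical, not part of this PR): a one-pass rewrite that also reaches the plans nested inside subquery expressions, which a plain `transformUp` would not visit.

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}

// Hypothetical helper: replace `Filter(false, child)` with an empty LocalRelation
// everywhere, including inside the plans of scalar/lateral subquery expressions.
def pruneFalseFilters(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
  case Filter(condition, child) if condition == Literal.FalseLiteral =>
    LocalRelation(child.output)
}
```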

/**
* A variant of `collect`. This method not only applies the given function to all elements in this
* plan, but also considers all the plans in its (nested) subqueries
@@ -2613,6 +2613,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

  val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY =
    buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery")
      .internal()
      .doc("When true, the optimizer will inline subqueries with OneRowRelation as leaf nodes.")
      .version("3.2.0")
      .booleanConf
      .createWithDefault(true)

val TOP_K_SORT_FALLBACK_THRESHOLD =
buildConf("spark.sql.execution.topKSortFallbackThreshold")
.internal()
@@ -32,6 +32,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val x = AttributeReference("x", IntegerType)()
val y = AttributeReference("y", IntegerType)()
val z = AttributeReference("z", IntegerType)()
val t0 = OneRowRelation()
val testRelation = LocalRelation(a, b, c)
val testRelation2 = LocalRelation(x, y, z)

@@ -203,23 +204,24 @@

test("correlated values in project") {
val outerPlan = testRelation2
val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation())
val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation()))
val innerPlan = Project(Seq(OuterReference(x).as("x1"), OuterReference(y).as("y1")), t0)
val correctAnswer = Project(
Seq(x.as("x1"), y.as("y1"), x, y), DomainJoin(Seq(x, y), t0))
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
}

test("correlated values in project with alias") {
val outerPlan = testRelation2
val innerPlan =
Project(Seq(OuterReference(x), 'y1, 'sum),
Project(Seq(OuterReference(x).as("x1"), 'y1, 'sum),
Project(Seq(
OuterReference(x),
OuterReference(y).as("y1"),
Add(OuterReference(x), OuterReference(y)).as("sum")),
testRelation)).analyze
val correctAnswer =
Project(Seq(x, 'y1, 'sum, y),
Project(Seq(x, y.as("y1"), (x + y).as("sum"), y),
Project(Seq(x.as("x1"), 'y1, 'sum, x, y),
Project(Seq(x.as(x.name), y.as("y1"), (x + y).as("sum"), x, y),
DomainJoin(Seq(x, y), testRelation))).analyze
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
}
@@ -228,28 +230,28 @@
val outerPlan = testRelation2
val innerPlan =
Project(
Seq(OuterReference(x)),
Seq(OuterReference(x).as("x1")),
Filter(
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
testRelation
)
)
val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation))
val correctAnswer = Project(Seq(a.as("x1"), a, c), Filter(b === 1, testRelation))
check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c))
}

test("correlated values in project without correlated equality conditions in filter") {
val outerPlan = testRelation2
val innerPlan =
Project(
Seq(OuterReference(y)),
Seq(OuterReference(y).as("y1")),
Filter(
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
testRelation
)
)
val correctAnswer =
Project(Seq(y, a, c),
Project(Seq(y.as("y1"), y, a, c),
Filter(b === 1,
DomainJoin(Seq(y), testRelation)
)
