[SPARK-32638][SQL][FOLLOWUP] Move the plan rewriting methods to QueryPlan

### What changes were proposed in this pull request?

This is a follow-up of #29485.

It moves the plan rewriting methods from `Analyzer` to `QueryPlan`, so that they can work with `SparkPlan` as well. This PR also improves the rewrite to support a corner case (the attribute to be replaced appears together with an unresolved attribute) and makes it more general, so that `WidenSetOperationTypes` can rewrite the plan in one shot, as before.
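
As a sketch of the new contract (a hypothetical, minimal example built with Catalyst test utilities, not code from this diff): a rule passed to `transformUpWithNewOutput` returns the replacement node together with an old-to-new attribute mapping, and the method patches attribute references in the parent nodes in the same traversal.

```scala
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.IntegerType

// A leaf to be replaced, and a copy whose output attribute has a fresh exprId.
val oldRel = LocalRelation(AttributeReference("a", IntegerType)())
val newRel = LocalRelation(oldRel.output.head.newInstance())

// One-shot rewrite: return the new node plus the (old -> new) attribute pairs;
// the parent Project's reference to `a` is updated to the new exprId.
val plan: LogicalPlan = oldRel.select(oldRel.output.head)
val rewritten = plan.transformUpWithNewOutput {
  case p if p.fastEquals(oldRel) => newRel -> p.output.zip(newRel.output)
}
assert(rewritten.missingInput.isEmpty)
```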

### Why are the changes needed?

Code cleanup and generalization.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #29643 from cloud-fan/cleanup.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
cloud-fan authored and maropu committed Sep 8, 2020
1 parent 954cd9f commit 117a6f1
Showing 5 changed files with 136 additions and 161 deletions.
Analyzer.scala
@@ -123,127 +123,6 @@ object AnalysisContext {
   }
 }
 
-object Analyzer {
-
-  /**
-   * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
-   * This method also updates all the related references in the `plan` accordingly.
-   *
-   * @param plan to rewrite
-   * @param rewritePlanMap has mappings from old plans to new ones for the given `plan`.
-   * @return a rewritten plan and updated references related to a root node of
-   *         the given `plan` for rewriting it.
-   */
-  def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
-    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
-    if (plan.resolved) {
-      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
-      val newChildren = plan.children.map { child =>
-        // If not, we'd rewrite child plan recursively until we find the
-        // conflict node or reach the leaf node.
-        val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
-        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
-          // `attrMapping` is not only used to replace the attributes of the current `plan`,
-          // but also to be propagated to the parent plans of the current `plan`. Therefore,
-          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
-          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
-          // used by those parent plans).
-          (plan.outputSet ++ plan.references).contains(oldAttr)
-        }
-        newChild
-      }
-
-      val newPlan = if (rewritePlanMap.contains(plan)) {
-        rewritePlanMap(plan).withNewChildren(newChildren)
-      } else {
-        plan.withNewChildren(newChildren)
-      }
-
-      assert(!attrMapping.groupBy(_._1.exprId)
-        .exists(_._2.map(_._2.exprId).distinct.length > 1),
-        "Found duplicate rewrite attributes")
-
-      val attributeRewrites = AttributeMap(attrMapping)
-      // Using attrMapping from the children plans to rewrite their parent node.
-      // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-      val p = newPlan.transformExpressions {
-        case a: Attribute =>
-          updateAttr(a, attributeRewrites)
-        case s: SubqueryExpression =>
-          s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
-      }
-      attrMapping ++= plan.output.zip(p.output)
-        .filter { case (a1, a2) => a1.exprId != a2.exprId }
-      p -> attrMapping
-    } else {
-      // Just passes through unresolved nodes
-      plan.mapChildren {
-        rewritePlan(_, rewritePlanMap)._1
-      } -> Nil
-    }
-  }
-
-  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
-    val exprId = attrMap.getOrElse(attr, attr).exprId
-    attr.withExprId(exprId)
-  }
-
-  /**
-   * The outer plan may have old references and the function below updates the
-   * outer references to refer to the new attributes.
-   *
-   * For example (SQL):
-   * {{{
-   *   SELECT * FROM t1
-   *   INTERSECT
-   *   SELECT * FROM t1
-   *   WHERE EXISTS (SELECT 1
-   *                 FROM t2
-   *                 WHERE t1.c1 = t2.c1)
-   * }}}
-   * Plan before resolveReference rule.
-   *    'Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- 'Project [*]
-   *       +- Filter exists#257 [c1#245]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#245) = c1#251)
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#245,c2#246] parquet
-   * Plan after the resolveReference rule.
-   *    Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- Project [c1#259, c2#260]
-   *       +- Filter exists#257 [c1#259]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#259) = c1#251) => Updated
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#259,c2#260] parquet  => Outer plan's attributes are rewritten.
-   */
-  private def updateOuterReferencesInSubquery(
-      plan: LogicalPlan,
-      attrMap: AttributeMap[Attribute]): LogicalPlan = {
-    AnalysisHelper.allowInvokingTransformsInAnalyzer {
-      plan transformDown { case currentFragment =>
-        currentFragment transformExpressions {
-          case OuterReference(a: Attribute) =>
-            OuterReference(updateAttr(a, attrMap))
-          case s: SubqueryExpression =>
-            s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
-        }
-      }
-    }
-  }
-}
-
 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
  * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
@@ -1376,7 +1255,14 @@ class Analyzer(
     if (conflictPlans.isEmpty) {
       right
     } else {
-      Analyzer.rewritePlan(right, conflictPlans.toMap)._1
+      val planMapping = conflictPlans.toMap
+      right.transformUpWithNewOutput {
+        case oldPlan =>
+          val newPlanOpt = planMapping.get(oldPlan)
+          newPlanOpt.map { newPlan =>
+            newPlan -> oldPlan.output.zip(newPlan.output)
+          }.getOrElse(oldPlan -> Nil)
+      }
     }
   }
 
TypeCoercion.scala
@@ -329,50 +329,43 @@ object TypeCoercion {
   object WidenSetOperationTypes extends TypeCoercionRule {
 
     override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
-      val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
-      val newPlan = plan resolveOperatorsUp {
+      plan resolveOperatorsUpWithNewOutput {
         case s @ Except(left, right, isAll) if s.childrenResolved &&
             left.output.length == right.output.length && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            Except(newChildren.head._1, newChildren.last._1, isAll)
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            assert(newChildren.length == 2)
+            val attrMapping = left.output.zip(newChildren.head.output)
+            Except(newChildren.head, newChildren.last, isAll) -> attrMapping
           }
 
         case s @ Intersect(left, right, isAll) if s.childrenResolved &&
             left.output.length == right.output.length && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            Intersect(newChildren.head._1, newChildren.last._1, isAll)
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            assert(newChildren.length == 2)
+            val attrMapping = left.output.zip(newChildren.head.output)
+            Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping
          }
 
         case s: Union if s.childrenResolved && !s.byName &&
            s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(s.children)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            s.copy(children = newChildren.map(_._1))
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            val attrMapping = s.children.head.output.zip(newChildren.head.output)
+            s.copy(children = newChildren) -> attrMapping
           }
       }
-
-      if (rewritePlanMap.nonEmpty) {
-        assert(!plan.fastEquals(newPlan))
-        Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
-      } else {
-        plan
-      }
     }
 
     /** Build new children with the widest types for each attribute among all the children */
-    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
-      : Seq[(LogicalPlan, LogicalPlan)] = {
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
       require(children.forall(_.output.length == children.head.output.length))
 
       // Get a sequence of data types, each of which is the widest type of this specific attribute
@@ -408,16 +401,13 @@
     }
 
     /** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
-      : (LogicalPlan, LogicalPlan) = {
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
       val casted = plan.output.zip(targetTypes).map {
         case (e, dt) if e.dataType != dt =>
-          val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId)
-          alias -> alias.newInstance()
-        case (e, _) =>
-          e -> e
-      }.unzip
-      Project(casted._1, plan) -> Project(casted._2, plan)
+          Alias(Cast(e, dt, Some(SQLConf.get.sessionLocalTimeZone)), e.name)()
+        case (e, _) => e
+      }
+      Project(casted, plan)
     }
   }

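To make the rule's new behavior concrete, here is a rough sketch along the lines of the regression test below (assuming Catalyst's plan DSL; illustrative only, not part of this diff): widening a union of `DecimalType(10, 0)` and `DecimalType(11, 0)` columns wraps each child in a casting `Project`, and `resolveOperatorsUpWithNewOutput` fixes up any parent operator that still references the old output attribute in the same pass.

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.DecimalType

val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())

// Each Union child gets an extra Project casting `v` to the widest type,
// DecimalType(11, 0); the outer Project's reference to t1's `v` is rewritten
// to the new aliased attribute, so no dangling reference remains.
val widened = TypeCoercion.WidenSetOperationTypes(
  t1.union(t2).select(t1.output.head))
assert(widened.missingInput.isEmpty)
```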
QueryPlan.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst.plans
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTag}
@@ -168,6 +170,89 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     }.toSeq
   }
 
+  /**
+   * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node
+   * with a new one that has different output expr IDs, by updating the attribute references in
+   * the parent nodes accordingly.
+   *
+   * @param rule the function to transform plan nodes, and return new nodes with attributes mapping
+   *             from old attributes to new attributes. The attribute mapping will be used to
+   *             rewrite attribute references in the parent nodes.
+   * @param skipCond a boolean condition to indicate if we can skip transforming a plan node to
+   *                 save time.
+   */
+  def transformUpWithNewOutput(
+      rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])],
+      skipCond: PlanType => Boolean = _ => false): PlanType = {
+    def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = {
+      if (skipCond(plan)) {
+        plan -> Nil
+      } else {
+        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+        var newPlan = plan.mapChildren { child =>
+          val (newChild, childAttrMapping) = rewrite(child)
+          attrMapping ++= childAttrMapping
+          newChild
+        }
+
+        val attrMappingForCurrentPlan = attrMapping.filter {
+          // The `attrMappingForCurrentPlan` is used to replace the attributes of the
+          // current `plan`, so the `oldAttr` must be part of `plan.references`.
+          case (oldAttr, _) => plan.references.contains(oldAttr)
+        }
+
+        val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
+          rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil)
+        }
+        newPlan = planAfterRule
+
+        if (attrMappingForCurrentPlan.nonEmpty) {
+          assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
+            .exists(_._2.map(_._2.exprId).distinct.length > 1),
+            "Found duplicate rewrite attributes")
+
+          val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
+          // Using attrMapping from the children plans to rewrite their parent node.
+          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+          newPlan = newPlan.transformExpressions {
+            case a: AttributeReference =>
+              updateAttr(a, attributeRewrites)
+            case pe: PlanExpression[PlanType] =>
+              pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attributeRewrites))
+          }
+        }
+
+        attrMapping ++= newAttrMapping.filter {
+          case (a1, a2) => a1.exprId != a2.exprId
+        }
+        newPlan -> attrMapping
+      }
+    }
+    rewrite(this)._1
+  }
+
+  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+    val exprId = attrMap.getOrElse(attr, attr).exprId
+    attr.withExprId(exprId)
+  }
+
+  /**
+   * The outer plan may have old references and the function below updates the
+   * outer references to refer to the new attributes.
+   */
+  private def updateOuterReferencesInSubquery(
+      plan: PlanType,
+      attrMap: AttributeMap[Attribute]): PlanType = {
+    plan.transformDown { case currentFragment =>
+      currentFragment.transformExpressions {
+        case OuterReference(a: AttributeReference) =>
+          OuterReference(updateAttr(a, attrMap))
+        case pe: PlanExpression[PlanType] =>
+          pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap))
+      }
+    }
+  }
+
   lazy val schema: StructType = StructType.fromAttributes(output)
 
   /** Returns the output schema in the tree format. */
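
The `skipCond` parameter is what lets callers prune entire subtrees. A minimal sketch (hypothetical example, a variation of the earlier one): when `skipCond` holds for a node, that node is returned untouched with an empty attribute mapping, which is exactly how `resolveOperatorsUpWithNewOutput` in the next file skips already-analyzed plans.

```scala
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.IntegerType

val rel = LocalRelation(AttributeReference("a", IntegerType)())
val fresh = LocalRelation(rel.output.head.newInstance())
val plan: LogicalPlan = rel.select(rel.output.head)

// skipCond short-circuits the recursion: with a condition that always holds,
// the rule never fires and the plan comes back unchanged.
val untouched = plan.transformUpWithNewOutput(
  rule = { case p if p.fastEquals(rel) => fresh -> p.output.zip(fresh.output) },
  skipCond = _ => true)
assert(untouched.fastEquals(plan))
```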
AnalysisHelper.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
 import org.apache.spark.util.Utils
@@ -120,6 +120,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
     }
   }
 
+  /**
+   * A variant of `transformUpWithNewOutput`, which skips touching already analyzed plan.
+   */
+  def resolveOperatorsUpWithNewOutput(
+      rule: PartialFunction[LogicalPlan, (LogicalPlan, Seq[(Attribute, Attribute)])])
+    : LogicalPlan = {
+    if (!analyzed) {
+      transformUpWithNewOutput(rule, skipCond = _.analyzed)
+    } else {
+      self
+    }
+  }
+
   /**
    * Recursively transforms the expressions of a tree, skipping nodes that have already
    * been analyzed.
TypeCoercionSuite.scala
@@ -1419,12 +1419,13 @@ class TypeCoercionSuite extends AnalysisTest {
   test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
     val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
     val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
-    val p1 = t1.select(t1.output.head)
-    val p2 = t2.select(t2.output.head)
+    val p1 = t1.select(t1.output.head).as("p1")
+    val p2 = t2.select(t2.output.head).as("p2")
     val union = p1.union(p2)
-    val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v"))
     assert(wp1.isInstanceOf[Project])
-    assert(wp1.missingInput.isEmpty)
+    // The attribute `p1.output.head` should be replaced in the root `Project`.
+    assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty))
     val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
     assert(wp2.isInstanceOf[Aggregate])
     assert(wp2.missingInput.isEmpty)