[SPARK-32638][SQL] Corrects references when adding aliases in WidenSetOperationTypes #29485

Closed · wants to merge 4 commits
@@ -123,6 +123,127 @@ object AnalysisContext {
}
}

object Analyzer {

  /**
   * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
   * This method also updates all the related references in the `plan` accordingly.
   *
   * @param plan the plan to rewrite
   * @param rewritePlanMap mappings from old plans to new ones for the given `plan`.
   * @return the rewritten plan, along with attribute mappings from the old output
   *         attributes to the new ones, for updating references in parent plans.
   */
  def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
@maropu (Member, Author) commented on Aug 28, 2020:

I rewrote the existing rewritePlan a bit, then just reused it for WidenSetOperationTypes. Does this update satisfy your intention? #29485 (comment)
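For reference, here is a minimal sketch of the reworked contract (illustrative only, not part of the diff; the relation and plans below are hypothetical):

```scala
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.types.IntegerType

// Map an old sub-plan to a replacement with fresh ExprIds, and get back the
// rewritten plan plus the (oldAttr, newAttr) pairs that parent nodes need
// in order to fix up their references.
val oldChild = LocalRelation(AttributeReference("v", IntegerType)())
val newChild = oldChild.newInstance()  // same schema, fresh ExprIds
val plan = Project(oldChild.output, oldChild)

val (rewritten, attrMapping) =
  Analyzer.rewritePlan(plan, Map[LogicalPlan, LogicalPlan](oldChild -> newChild))
// `rewritten` now projects `newChild`'s attributes, and `attrMapping` pairs
// each old output attribute with its fresh counterpart.
```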

    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
    if (plan.resolved) {
      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
      val newChildren = plan.children.map { child =>
        // Rewrite each child plan recursively until we find a node to
        // replace or reach a leaf node.
        val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
          // `attrMapping` is not only used to replace the attributes of the current `plan`,
          // but also to be propagated to the parent plans of the current `plan`. Therefore,
          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
          // used by those parent plans).
          (plan.outputSet ++ plan.references).contains(oldAttr)
        }
        newChild
      }

      val newPlan = if (rewritePlanMap.contains(plan)) {
        rewritePlanMap(plan).withNewChildren(newChildren)
      } else {
        plan.withNewChildren(newChildren)
      }

      assert(!attrMapping.groupBy(_._1.exprId)
        .exists(_._2.map(_._2.exprId).distinct.length > 1),
        "Found duplicate rewrite attributes")

      val attributeRewrites = AttributeMap(attrMapping)
      // Use attrMapping from the children plans to rewrite their parent node.
      // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
      val p = newPlan.transformExpressions {
        case a: Attribute =>
          updateAttr(a, attributeRewrites)
        case s: SubqueryExpression =>
          s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
      }
      attrMapping ++= plan.output.zip(p.output)
        .filter { case (a1, a2) => a1.exprId != a2.exprId }
      p -> attrMapping
    } else {
      // Just pass through unresolved nodes.
      plan.mapChildren {
        rewritePlan(_, rewritePlanMap)._1
      } -> Nil
    }
  }

  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
    val exprId = attrMap.getOrElse(attr, attr).exprId
    attr.withExprId(exprId)
  }

/**
* The outer plan may have old references and the function below updates the
* outer references to refer to the new attributes.
*
* For example (SQL):
* {{{
* SELECT * FROM t1
* INTERSECT
* SELECT * FROM t1
* WHERE EXISTS (SELECT 1
* FROM t2
* WHERE t1.c1 = t2.c1)
* }}}
   * Plan before the resolveReference rule.
* 'Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- 'Project [*]
* +- Filter exists#257 [c1#245]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#245) = c1#251)
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#245,c2#246] parquet
* Plan after the resolveReference rule.
* Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- Project [c1#259, c2#260]
* +- Filter exists#257 [c1#259]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#259) = c1#251) => Updated
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten.
*/
  private def updateOuterReferencesInSubquery(
      plan: LogicalPlan,
      attrMap: AttributeMap[Attribute]): LogicalPlan = {
    AnalysisHelper.allowInvokingTransformsInAnalyzer {
      plan transformDown { case currentFragment =>
        currentFragment transformExpressions {
          case OuterReference(a: Attribute) =>
            OuterReference(updateAttr(a, attrMap))
          case s: SubqueryExpression =>
            s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
        }
      }
    }
  }
}

/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
* [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
@@ -1251,109 +1372,7 @@ class Analyzer(
      if (conflictPlans.isEmpty) {
        right
      } else {
        rewritePlan(right, conflictPlans.toMap)._1
      }
    }

    private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan])
      : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
      if (conflictPlanMap.contains(plan)) {
        // If the plan is the one that conflicts with the left one, we'd
        // just replace it with the new plan and collect the rewrite
        // attributes for the parent node.
        val newRelation = conflictPlanMap(plan)
        newRelation -> plan.output.zip(newRelation.output)
      } else {
        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
        val newPlan = plan.mapChildren { child =>
          // If not, we'd rewrite the child plan recursively until we find the
          // conflict node or reach a leaf node.
          val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap)
          attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
            // `attrMapping` is not only used to replace the attributes of the current `plan`,
            // but also to be propagated to the parent plans of the current `plan`. Therefore,
            // the `oldAttr` must be part of either `plan.references` (so that it can be used to
            // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
            // used by those parent plans).
            (plan.outputSet ++ plan.references).contains(oldAttr)
          }
          newChild
        }

        if (attrMapping.isEmpty) {
          newPlan -> attrMapping.toSeq
        } else {
          assert(!attrMapping.groupBy(_._1.exprId)
            .exists(_._2.map(_._2.exprId).distinct.length > 1),
            "Found duplicate rewrite attributes")
          val attributeRewrites = AttributeMap(attrMapping.toSeq)
          // Use attrMapping from the children plans to rewrite their parent node.
          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
          newPlan.transformExpressions {
            case a: Attribute =>
              dedupAttr(a, attributeRewrites)
            case s: SubqueryExpression =>
              s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
          } -> attrMapping.toSeq
        }
      }
    }

    private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
      val exprId = attrMap.getOrElse(attr, attr).exprId
      attr.withExprId(exprId)
    }

/**
* The outer plan may have been de-duplicated and the function below updates the
* outer references to refer to the de-duplicated attributes.
*
* For example (SQL):
* {{{
* SELECT * FROM t1
* INTERSECT
* SELECT * FROM t1
* WHERE EXISTS (SELECT 1
* FROM t2
* WHERE t1.c1 = t2.c1)
* }}}
     * Plan before the resolveReference rule.
* 'Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- 'Project [*]
* +- Filter exists#257 [c1#245]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#245) = c1#251)
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#245,c2#246] parquet
* Plan after the resolveReference rule.
* Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- Project [c1#259, c2#260]
* +- Filter exists#257 [c1#259]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#259) = c1#251) => Updated
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated.
*/
    private def dedupOuterReferencesInSubquery(
        plan: LogicalPlan,
        attrMap: AttributeMap[Attribute]): LogicalPlan = {
      plan transformDown { case currentFragment =>
        currentFragment transformExpressions {
          case OuterReference(a: Attribute) =>
            OuterReference(dedupAttr(a, attrMap))
          case s: SubqueryExpression =>
            s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
        }
        Analyzer.rewritePlan(right, conflictPlans.toMap)._1
      }
    }

@@ -326,29 +326,53 @@ object TypeCoercion {
*
* This rule is only applied to Union/Except/Intersect
*/
  object WidenSetOperationTypes extends Rule[LogicalPlan] {

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
      case s @ Except(left, right, isAll) if s.childrenResolved &&
          left.output.length == right.output.length && !s.resolved =>
        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
        assert(newChildren.length == 2)
        Except(newChildren.head, newChildren.last, isAll)

      case s @ Intersect(left, right, isAll) if s.childrenResolved &&
          left.output.length == right.output.length && !s.resolved =>
        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
        assert(newChildren.length == 2)
        Intersect(newChildren.head, newChildren.last, isAll)

      case s: Union if s.childrenResolved && !s.byName &&
  object WidenSetOperationTypes extends TypeCoercionRule {

    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
      val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
      val newPlan = plan resolveOperatorsUp {
        case s @ Except(left, right, isAll) if s.childrenResolved &&
            left.output.length == right.output.length && !s.resolved =>
          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
          if (newChildren.nonEmpty) {
            rewritePlanMap ++= newChildren
            Except(newChildren.head._1, newChildren.last._1, isAll)
          } else {
            s
          }

        case s @ Intersect(left, right, isAll) if s.childrenResolved &&
            left.output.length == right.output.length && !s.resolved =>
          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
          if (newChildren.nonEmpty) {
            rewritePlanMap ++= newChildren
            Intersect(newChildren.head._1, newChildren.last._1, isAll)
          } else {
            s
          }

        case s: Union if s.childrenResolved && !s.byName &&
            s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
          s.copy(children = newChildren)
          val newChildren = buildNewChildrenWithWiderTypes(s.children)
          if (newChildren.nonEmpty) {
            rewritePlanMap ++= newChildren
            s.copy(children = newChildren.map(_._1))
          } else {
            s
          }
      }

      if (rewritePlanMap.nonEmpty) {
        assert(!plan.fastEquals(newPlan))
        Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
      } else {
        plan
      }
    }

    /** Build new children with the widest types for each attribute among all the children */
    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
      : Seq[(LogicalPlan, LogicalPlan)] = {
      require(children.forall(_.output.length == children.head.output.length))

      // Get a sequence of data types, each of which is the widest type of this specific attribute
@@ -360,8 +384,7 @@ object TypeCoercion {
        // Add an extra Project if the targetTypes are different from the original types.
        children.map(widenTypes(_, targetTypes))
      } else {
        // Unable to find a target type to widen, then just return the original set.
        children
        Nil
      }
    }

@@ -385,12 +408,16 @@
}

    /** Given a plan, add an extra project on top to widen some columns' data types. */
    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
      : (LogicalPlan, LogicalPlan) = {
      val casted = plan.output.zip(targetTypes).map {
        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
        case (e, _) => e
      }
      Project(casted, plan)
        case (e, dt) if e.dataType != dt =>
          val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId)
          alias -> alias.newInstance()
        case (e, _) =>
          e -> e
      }.unzip
      Project(casted._1, plan) -> Project(casted._2, plan)
Contributor commented:
what are we doing here?

@maropu (Member, Author) replied:
This generates a rewrite map used for Analyzer.rewritePlan. rewritePlan assumes the plan structure is the same before and after rewriting, so this WidenSetOperationTypes rule now does a two-phase transformation, as follows:

### Input Plan (Query described in the PR description) ###
Project [v#1]
+- SubqueryAlias t
   +- Union
      :+- Project [v#1]
      :   +- SubqueryAlias t3
      :      ...
      +- Project [v#2]
         +- Project [CheckOverflow((promote_precision(cast(v#1 as decimal(11,0))) + promote_precision(cast(v#1 as decimal(11,0)))), DecimalType(11,0), true) AS v#2]
            +- SubqueryAlias t3
               ...

### Phase-1 (Adds Project, but does not update ExprId) ###
Project [v#1]
+- SubqueryAlias t
   +- Union
      :- Project [cast(v#1 as decimal(11,0)) AS v#1] <--- !!!Adds Project to widen a type!!!
      :  +- Project [v#1]
      :     +- SubqueryAlias t3
      :        ...
      +- Project [v#2]
         +- Project [CheckOverflow((promote_precision(cast(v#1 as decimal(11,0))) + promote_precision(cast(v#1 as decimal(11,0)))), DecimalType(11,0), true) AS v#2]
            ...

### Phase-2 ###
// Analyzer.rewritePlan updates ExprIds based on a rewrite map:
// `Project [cast(v#1 as decimal(11,0)) AS v#1]` => Project [cast(v#1 as decimal(11,0)) AS v#3]
Project [v#3] <--- !!!Updates ExprId!!!
+- SubqueryAlias t
   +- Union
      :- Project [cast(v#1 as decimal(11,0)) AS v#3] <--- !!!Updates ExprId!!!
      :  +- Project [v#1]
      :     +- SubqueryAlias t3
      :        ...
      +- Project [v#2]
         +- Project [CheckOverflow((promote_precision(cast(v#1 as decimal(11,0))) + promote_precision(cast(v#1 as decimal(11,0)))), DecimalType(11,0), true) AS v#2]
            ...
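
Concretely, each widened column yields a pair of aliases, one keeping the old ExprId (phase-1) and one with a fresh ExprId (phase-2). A small sketch of that pairing (hypothetical input, mirroring the v#1 -> v#3 example above):

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.types.DecimalType

val v = AttributeReference("v", DecimalType(10, 0))()  // think of it as v#1
// Phase-1 alias: widens the type but reuses the old ExprId, so the
// surrounding tree still resolves unchanged.
val keepId = Alias(Cast(v, DecimalType(11, 0)), v.name)(exprId = v.exprId)
// Phase-2 alias: same expression, fresh ExprId (the v#3 above).
val freshId = keepId.newInstance()
// The rule records the pair (Project with keepId, Project with freshId) in the
// rewrite map, and Analyzer.rewritePlan swaps in the fresh ids while updating
// every reference above the Union.
```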

}
}

@@ -21,13 +21,12 @@ import java.sql.Timestamp

import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

class TypeCoercionSuite extends AnalysisTest {
import TypeCoercionSuite._
@@ -1417,6 +1416,20 @@
}
}

test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
val p1 = t1.select(t1.output.head)
val p2 = t2.select(t2.output.head)
val union = p1.union(p2)
val wp1 = widenSetOperationTypes(union.select(p1.output.head))
assert(wp1.isInstanceOf[Project])
assert(wp1.missingInput.isEmpty)
val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
assert(wp2.isInstanceOf[Aggregate])
assert(wp2.missingInput.isEmpty)
}
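
For context, the failing query had roughly this shape (reconstructed from the plan trees in the review discussion above; assumes a SparkSession `spark` and a table t3 with a single decimal(10,0) column v):

```scala
// Before this fix, widening gave the Union child a new ExprId for `v` without
// updating the parent operators, so the outer reference t.v failed to resolve;
// with the fix, the query analyzes successfully.
val df = spark.sql(
  """SELECT t.v FROM (
    |  SELECT v FROM t3
    |  UNION ALL
    |  SELECT v + v AS v FROM t3
    |) t""".stripMargin)
df.explain(true)
```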

/**
* There are rules that need to not fire before child expressions get resolved.
* We use this test to make sure those rules do not fire early.