[SPARK-32638][SQL] Corrects references when adding aliases in WidenSetOperationTypes

### What changes were proposed in this pull request?

This PR fixes a bug where references can go missing when adding aliases to widen data types in `WidenSetOperationTypes`. For example:
```
CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
SELECT t.v FROM (
  SELECT v FROM t3
  UNION ALL
  SELECT v + v AS v FROM t3
) t;

org.apache.spark.sql.AnalysisException: Resolved attribute(s) v#1 missing from v#3 in operator !Project [v#1]. Attribute(s) with the same name appear in the operation: v. Please check if the right attribute(s) are used.;;
!Project [v#1]  <------ this reference went missing
+- SubqueryAlias t
   +- Union
      :- Project [cast(v#1 as decimal(11,0)) AS v#3]
      :  +- Project [v#1]
      :     +- SubqueryAlias t3
      :        +- SubqueryAlias tbl
      :           +- LocalRelation [v#1]
      +- Project [v#2]
         +- Project [CheckOverflow((promote_precision(cast(v#1 as decimal(11,0))) + promote_precision(cast(v#1 as decimal(11,0)))), DecimalType(11,0), true) AS v#2]
            +- SubqueryAlias t3
               +- SubqueryAlias tbl
                  +- LocalRelation [v#1]
```
In this case, `WidenSetOperationTypes` adds the alias `cast(v#1 as decimal(11,0)) AS v#3`, which changes the attribute's `exprId`, so the reference `v#1` in the top `Project` no longer resolves. This PR corrects such references (updating the `exprId` and the widened `dataType`) after the rule adds aliases.
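At a high level, the fixed rule has `widenTypes` return a pair of `Project`s: the first keeps the original `exprId`s so the not-yet-rewritten parent still resolves, while the second gives the widened aliases fresh `exprId`s via `newInstance()`. The new `Analyzer.rewritePlan` helper then swaps in the second plan and remaps parent references. A minimal sketch of that remapping step (simplified; `remapReferences` is a hypothetical name, the committed logic lives in `Analyzer.rewritePlan` and `updateAttr`):
```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Sketch: point stale references (e.g. v#1) at the widened attributes (e.g. v#3)
// produced by the rewritten child plan.
def remapReferences(
    parent: LogicalPlan,
    mapping: Seq[(Attribute, Attribute)]): LogicalPlan = {
  val rewrites = AttributeMap(mapping)
  parent.transformExpressions {
    // Keep the attribute, but adopt the replacement's exprId.
    case a: Attribute => a.withExprId(rewrites.getOrElse(a, a).exprId)
  }
}
```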

### Why are the changes needed?

Bug fix.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit tests
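The gist of the new `TypeCoercionSuite` case, condensed into a sketch (it leans on the suite's `widenSetOperationTypes` helper and the catalyst DSL imports): after widening, no operator may reference an attribute its children no longer produce.
```scala
val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
val p1 = t1.select(t1.output.head)
val union = p1.union(t2.select(t2.output.head))
// Before this fix, the outer Project kept a stale reference to p1's output
// after WidenSetOperationTypes re-aliased it, so missingInput was non-empty.
val widened = widenSetOperationTypes(union.select(p1.output.head))
assert(widened.missingInput.isEmpty)
```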

Closes #29485 from maropu/SPARK-32638.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
maropu authored and cloud-fan committed Sep 3, 2020
1 parent ffd5227 commit a6114d8
Showing 9 changed files with 378 additions and 134 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -123,6 +123,127 @@ object AnalysisContext {
   }
 }

+object Analyzer {
+
+  /**
+   * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
+   * This method also updates all the related references in the `plan` accordingly.
+   *
+   * @param plan to rewrite
+   * @param rewritePlanMap has mappings from old plans to new ones for the given `plan`.
+   * @return a rewritten plan and updated references related to a root node of
+   *         the given `plan` for rewriting it.
+   */
+  def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
+    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
+    if (plan.resolved) {
+      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+      val newChildren = plan.children.map { child =>
+        // If not, we'd rewrite child plan recursively until we find the
+        // conflict node or reach the leaf node.
+        val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
+        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
+          // `attrMapping` is not only used to replace the attributes of the current `plan`,
+          // but also to be propagated to the parent plans of the current `plan`. Therefore,
+          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
+          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
+          // used by those parent plans).
+          (plan.outputSet ++ plan.references).contains(oldAttr)
+        }
+        newChild
+      }
+
+      val newPlan = if (rewritePlanMap.contains(plan)) {
+        rewritePlanMap(plan).withNewChildren(newChildren)
+      } else {
+        plan.withNewChildren(newChildren)
+      }
+
+      assert(!attrMapping.groupBy(_._1.exprId)
+        .exists(_._2.map(_._2.exprId).distinct.length > 1),
+        "Found duplicate rewrite attributes")
+
+      val attributeRewrites = AttributeMap(attrMapping)
+      // Using attrMapping from the children plans to rewrite their parent node.
+      // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+      val p = newPlan.transformExpressions {
+        case a: Attribute =>
+          updateAttr(a, attributeRewrites)
+        case s: SubqueryExpression =>
+          s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
+      }
+      attrMapping ++= plan.output.zip(p.output)
+        .filter { case (a1, a2) => a1.exprId != a2.exprId }
+      p -> attrMapping
+    } else {
+      // Just passes through unresolved nodes
+      plan.mapChildren {
+        rewritePlan(_, rewritePlanMap)._1
+      } -> Nil
+    }
+  }
+
+  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+    val exprId = attrMap.getOrElse(attr, attr).exprId
+    attr.withExprId(exprId)
+  }
+
+  /**
+   * The outer plan may have old references and the function below updates the
+   * outer references to refer to the new attributes.
+   *
+   * For example (SQL):
+   * {{{
+   *   SELECT * FROM t1
+   *   INTERSECT
+   *   SELECT * FROM t1
+   *   WHERE EXISTS (SELECT 1
+   *                 FROM t2
+   *                 WHERE t1.c1 = t2.c1)
+   * }}}
+   * Plan before resolveReference rule.
+   *    'Intersect
+   *    :- Project [c1#245, c2#246]
+   *    :  +- SubqueryAlias t1
+   *    :     +- Relation[c1#245,c2#246] parquet
+   *    +- 'Project [*]
+   *       +- Filter exists#257 [c1#245]
+   *          :  +- Project [1 AS 1#258]
+   *          :     +- Filter (outer(c1#245) = c1#251)
+   *          :        +- SubqueryAlias t2
+   *          :           +- Relation[c1#251,c2#252] parquet
+   *          +- SubqueryAlias t1
+   *             +- Relation[c1#245,c2#246] parquet
+   * Plan after the resolveReference rule.
+   *    Intersect
+   *    :- Project [c1#245, c2#246]
+   *    :  +- SubqueryAlias t1
+   *    :     +- Relation[c1#245,c2#246] parquet
+   *    +- Project [c1#259, c2#260]
+   *       +- Filter exists#257 [c1#259]
+   *          :  +- Project [1 AS 1#258]
+   *          :     +- Filter (outer(c1#259) = c1#251) => Updated
+   *          :        +- SubqueryAlias t2
+   *          :           +- Relation[c1#251,c2#252] parquet
+   *          +- SubqueryAlias t1
+   *             +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten.
+   */
+  private def updateOuterReferencesInSubquery(
+      plan: LogicalPlan,
+      attrMap: AttributeMap[Attribute]): LogicalPlan = {
+    AnalysisHelper.allowInvokingTransformsInAnalyzer {
+      plan transformDown { case currentFragment =>
+        currentFragment transformExpressions {
+          case OuterReference(a: Attribute) =>
+            OuterReference(updateAttr(a, attrMap))
+          case s: SubqueryExpression =>
+            s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
+        }
+      }
+    }
+  }
+}

 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
  * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
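An illustrative use of the new helper's contract (a self-contained sketch with invented plans, not code from this commit):
```scala
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.DecimalType

val rel = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val v = rel.output.head
val oldChild = Project(Seq(v), rel)
// A widened twin of oldChild: same column name, wider type, fresh exprId.
val newChild = Project(Seq(Alias(Cast(v, DecimalType(11, 0)), "v")()), rel)
val parent = oldChild.select(v)
// rewritePlan swaps in newChild and returns the attribute pairs that ancestors
// of `parent` would need to fix up their own references.
val (rewritten, attrMapping) = Analyzer.rewritePlan(parent, Map(oldChild -> newChild))
assert(rewritten.missingInput.isEmpty)
```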
@@ -1255,109 +1376,7 @@ class Analyzer(
       if (conflictPlans.isEmpty) {
         right
       } else {
-        rewritePlan(right, conflictPlans.toMap)._1
-      }
-    }
-
-    private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan])
-      : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
-      if (conflictPlanMap.contains(plan)) {
-        // If the plan is the one that conflict the with left one, we'd
-        // just replace it with the new plan and collect the rewrite
-        // attributes for the parent node.
-        val newRelation = conflictPlanMap(plan)
-        newRelation -> plan.output.zip(newRelation.output)
-      } else {
-        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
-        val newPlan = plan.mapChildren { child =>
-          // If not, we'd rewrite child plan recursively until we find the
-          // conflict node or reach the leaf node.
-          val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap)
-          attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
-            // `attrMapping` is not only used to replace the attributes of the current `plan`,
-            // but also to be propagated to the parent plans of the current `plan`. Therefore,
-            // the `oldAttr` must be part of either `plan.references` (so that it can be used to
-            // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
-            // used by those parent plans).
-            (plan.outputSet ++ plan.references).contains(oldAttr)
-          }
-          newChild
-        }
-
-        if (attrMapping.isEmpty) {
-          newPlan -> attrMapping.toSeq
-        } else {
-          assert(!attrMapping.groupBy(_._1.exprId)
-            .exists(_._2.map(_._2.exprId).distinct.length > 1),
-            "Found duplicate rewrite attributes")
-          val attributeRewrites = AttributeMap(attrMapping.toSeq)
-          // Using attrMapping from the children plans to rewrite their parent node.
-          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-          newPlan.transformExpressions {
-            case a: Attribute =>
-              dedupAttr(a, attributeRewrites)
-            case s: SubqueryExpression =>
-              s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
-          } -> attrMapping.toSeq
-        }
-      }
-    }
-
-    private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
-      val exprId = attrMap.getOrElse(attr, attr).exprId
-      attr.withExprId(exprId)
-    }
-
-    /**
-     * The outer plan may have been de-duplicated and the function below updates the
-     * outer references to refer to the de-duplicated attributes.
-     *
-     * For example (SQL):
-     * {{{
-     *   SELECT * FROM t1
-     *   INTERSECT
-     *   SELECT * FROM t1
-     *   WHERE EXISTS (SELECT 1
-     *                 FROM t2
-     *                 WHERE t1.c1 = t2.c1)
-     * }}}
-     * Plan before resolveReference rule.
-     *    'Intersect
-     *    :- Project [c1#245, c2#246]
-     *    :  +- SubqueryAlias t1
-     *    :     +- Relation[c1#245,c2#246] parquet
-     *    +- 'Project [*]
-     *       +- Filter exists#257 [c1#245]
-     *          :  +- Project [1 AS 1#258]
-     *          :     +- Filter (outer(c1#245) = c1#251)
-     *          :        +- SubqueryAlias t2
-     *          :           +- Relation[c1#251,c2#252] parquet
-     *          +- SubqueryAlias t1
-     *             +- Relation[c1#245,c2#246] parquet
-     * Plan after the resolveReference rule.
-     *    Intersect
-     *    :- Project [c1#245, c2#246]
-     *    :  +- SubqueryAlias t1
-     *    :     +- Relation[c1#245,c2#246] parquet
-     *    +- Project [c1#259, c2#260]
-     *       +- Filter exists#257 [c1#259]
-     *          :  +- Project [1 AS 1#258]
-     *          :     +- Filter (outer(c1#259) = c1#251) => Updated
-     *          :        +- SubqueryAlias t2
-     *          :           +- Relation[c1#251,c2#252] parquet
-     *          +- SubqueryAlias t1
-     *             +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated.
-     */
-    private def dedupOuterReferencesInSubquery(
-        plan: LogicalPlan,
-        attrMap: AttributeMap[Attribute]): LogicalPlan = {
-      plan transformDown { case currentFragment =>
-        currentFragment transformExpressions {
-          case OuterReference(a: Attribute) =>
-            OuterReference(dedupAttr(a, attrMap))
-          case s: SubqueryExpression =>
-            s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
-        }
+        Analyzer.rewritePlan(right, conflictPlans.toMap)._1
       }
     }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -326,29 +326,53 @@ object TypeCoercion {
    *
    * This rule is only applied to Union/Except/Intersect
    */
-  object WidenSetOperationTypes extends Rule[LogicalPlan] {
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
-      case s @ Except(left, right, isAll) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        Except(newChildren.head, newChildren.last, isAll)
-
-      case s @ Intersect(left, right, isAll) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        Intersect(newChildren.head, newChildren.last, isAll)
-
-      case s: Union if s.childrenResolved && !s.byName &&
-          s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
-        s.copy(children = newChildren)
-    }
+  object WidenSetOperationTypes extends TypeCoercionRule {
+
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
+      val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
+      val newPlan = plan resolveOperatorsUp {
+        case s @ Except(left, right, isAll) if s.childrenResolved &&
+            left.output.length == right.output.length && !s.resolved =>
+          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            Except(newChildren.head._1, newChildren.last._1, isAll)
+          } else {
+            s
+          }
+
+        case s @ Intersect(left, right, isAll) if s.childrenResolved &&
+            left.output.length == right.output.length && !s.resolved =>
+          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            Intersect(newChildren.head._1, newChildren.last._1, isAll)
+          } else {
+            s
+          }
+
+        case s: Union if s.childrenResolved && !s.byName &&
+            s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
+          val newChildren = buildNewChildrenWithWiderTypes(s.children)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            s.copy(children = newChildren.map(_._1))
+          } else {
+            s
+          }
+      }
+
+      if (rewritePlanMap.nonEmpty) {
+        assert(!plan.fastEquals(newPlan))
+        Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
+      } else {
+        plan
+      }
+    }
 
     /** Build new children with the widest types for each attribute among all the children */
-    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
+      : Seq[(LogicalPlan, LogicalPlan)] = {
       require(children.forall(_.output.length == children.head.output.length))
 
       // Get a sequence of data types, each of which is the widest type of this specific attribute
@@ -360,8 +384,7 @@ object TypeCoercion {
         // Add an extra Project if the targetTypes are different from the original types.
         children.map(widenTypes(_, targetTypes))
       } else {
-        // Unable to find a target type to widen, then just return the original set.
-        children
+        Nil
       }
     }

@@ -385,12 +408,16 @@
     }
 
     /** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
+      : (LogicalPlan, LogicalPlan) = {
       val casted = plan.output.zip(targetTypes).map {
-        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
-        case (e, _) => e
-      }
-      Project(casted, plan)
+        case (e, dt) if e.dataType != dt =>
+          val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId)
+          alias -> alias.newInstance()
+        case (e, _) =>
+          e -> e
+      }.unzip
+      Project(casted._1, plan) -> Project(casted._2, plan)
     }
   }
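Why `widenTypes` now returns a pair of `Project`s, sketched for the motivating query (attribute ids illustrative):
```scala
// first  -- recorded as the key in rewritePlanMap; it reuses the original
//           exprId, so parents that haven't been rewritten yet still resolve:
//   Project [cast(v#1 as decimal(11,0)) AS v#1]
// second -- the value; newInstance() mints a fresh exprId:
//   Project [cast(v#1 as decimal(11,0)) AS v#3]
// Analyzer.rewritePlan then swaps first -> second and remaps v#1 to v#3
// in the parent operators.
```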

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -21,13 +21,12 @@ import java.sql.Timestamp

 import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
 
 class TypeCoercionSuite extends AnalysisTest {
   import TypeCoercionSuite._
@@ -1417,6 +1416,20 @@ class TypeCoercionSuite extends AnalysisTest {
     }
   }
 
+  test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
+    val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
+    val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
+    val p1 = t1.select(t1.output.head)
+    val p2 = t2.select(t2.output.head)
+    val union = p1.union(p2)
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+    assert(wp1.isInstanceOf[Project])
+    assert(wp1.missingInput.isEmpty)
+    val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
+    assert(wp2.isInstanceOf[Aggregate])
+    assert(wp2.missingInput.isEmpty)
+  }
+
   /**
    * There are rules that need to not fire before child expressions get resolved.
    * We use this test to make sure those rules do not fire early.