Skip to content

Commit

Permalink
[SPARK-32638][SQL][3.0] Corrects references when adding aliases in Wi…
Browse files Browse the repository at this point in the history
…denSetOperationTypes

### What changes were proposed in this pull request?

This PR intends to fix a bug where references can be missing when adding aliases to widen data types in `WidenSetOperationTypes`. For example,
```
CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
SELECT t.v FROM (
  SELECT v FROM t3
  UNION ALL
  SELECT v + v AS v FROM t3
) t;

org.apache.spark.sql.AnalysisException: Resolved attribute(s) v#1 missing from v#3 in operator !Project [v#1]. Attribute(s) with the same name appear in the operation: v. Please check if the right attribute(s) are used.;;
!Project [v#1]  <------ the reference got missing
+- SubqueryAlias t
   +- Union
      :- Project [cast(v#1 as decimal(11,0)) AS v#3]
      :  +- Project [v#1]
      :     +- SubqueryAlias t3
      :        +- SubqueryAlias tbl
      :           +- LocalRelation [v#1]
      +- Project [v#2]
         +- Project [CheckOverflow((promote_precision(cast(v#1 as decimal(11,0))) + promote_precision(cast(v#1 as decimal(11,0)))), DecimalType(11,0), true) AS v#2]
            +- SubqueryAlias t3
               +- SubqueryAlias tbl
                  +- LocalRelation [v#1]
```
In the case, `WidenSetOperationTypes` added the alias `cast(v#1 as decimal(11,0)) AS v#3`, then the reference in the top `Project` got missing. This PR correct the reference (`exprId` and widen `dataType`) after adding aliases in the rule.

This backport for 3.0 comes from #29485 and #29643

### Why are the changes needed?

bugfixes

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit tests

Closes #29680 from maropu/SPARK-32638-BRANCH3.0.

Lead-authored-by: Wenchen Fan <wenchen@databricks.com>
Co-authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
  • Loading branch information
cloud-fan and maropu committed Sep 8, 2020
1 parent 8c0b9cb commit 3f20f14
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1239,108 +1239,13 @@ class Analyzer(
if (conflictPlans.isEmpty) {
right
} else {
rewritePlan(right, conflictPlans.toMap)._1
}
}

private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan])
: (LogicalPlan, Seq[(Attribute, Attribute)]) = {
if (conflictPlanMap.contains(plan)) {
// If the plan is the one that conflict the with left one, we'd
// just replace it with the new plan and collect the rewrite
// attributes for the parent node.
val newRelation = conflictPlanMap(plan)
newRelation -> plan.output.zip(newRelation.output)
} else {
val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
val newPlan = plan.mapChildren { child =>
// If not, we'd rewrite child plan recursively until we find the
// conflict node or reach the leaf node.
val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap)
attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
// `attrMapping` is not only used to replace the attributes of the current `plan`,
// but also to be propagated to the parent plans of the current `plan`. Therefore,
// the `oldAttr` must be part of either `plan.references` (so that it can be used to
// replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
// used by those parent plans).
(plan.outputSet ++ plan.references).contains(oldAttr)
}
newChild
}

if (attrMapping.isEmpty) {
newPlan -> attrMapping
} else {
assert(!attrMapping.groupBy(_._1.exprId)
.exists(_._2.map(_._2.exprId).distinct.length > 1),
"Found duplicate rewrite attributes")
val attributeRewrites = AttributeMap(attrMapping)
// Using attrMapping from the children plans to rewrite their parent node.
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
newPlan.transformExpressions {
case a: Attribute =>
dedupAttr(a, attributeRewrites)
case s: SubqueryExpression =>
s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
} -> attrMapping
}
}
}

private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
val exprId = attrMap.getOrElse(attr, attr).exprId
attr.withExprId(exprId)
}

/**
* The outer plan may have been de-duplicated and the function below updates the
* outer references to refer to the de-duplicated attributes.
*
* For example (SQL):
* {{{
* SELECT * FROM t1
* INTERSECT
* SELECT * FROM t1
* WHERE EXISTS (SELECT 1
* FROM t2
* WHERE t1.c1 = t2.c1)
* }}}
* Plan before resolveReference rule.
* 'Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- 'Project [*]
* +- Filter exists#257 [c1#245]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#245) = c1#251)
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#245,c2#246] parquet
* Plan after the resolveReference rule.
* Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- Project [c1#259, c2#260]
* +- Filter exists#257 [c1#259]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#259) = c1#251) => Updated
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated.
*/
private def dedupOuterReferencesInSubquery(
plan: LogicalPlan,
attrMap: AttributeMap[Attribute]): LogicalPlan = {
plan transformDown { case currentFragment =>
currentFragment transformExpressions {
case OuterReference(a: Attribute) =>
OuterReference(dedupAttr(a, attrMap))
case s: SubqueryExpression =>
s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
val planMapping = conflictPlans.toMap
right.transformUpWithNewOutput {
case oldPlan =>
val newPlanOpt = planMapping.get(oldPlan)
newPlanOpt.map { newPlan =>
newPlan -> oldPlan.output.zip(newPlan.output)
}.getOrElse(oldPlan -> Nil)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,25 +326,42 @@ object TypeCoercion {
*
* This rule is only applied to Union/Except/Intersect
*/
object WidenSetOperationTypes extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
case s @ Except(left, right, isAll) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
assert(newChildren.length == 2)
Except(newChildren.head, newChildren.last, isAll)

case s @ Intersect(left, right, isAll) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
assert(newChildren.length == 2)
Intersect(newChildren.head, newChildren.last, isAll)

case s: Union if s.childrenResolved &&
object WidenSetOperationTypes extends TypeCoercionRule {

override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
plan resolveOperatorsUpWithNewOutput {
case s @ Except(left, right, isAll) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
if (newChildren.isEmpty) {
s -> Nil
} else {
assert(newChildren.length == 2)
val attrMapping = left.output.zip(newChildren.head.output)
Except(newChildren.head, newChildren.last, isAll) -> attrMapping
}

case s @ Intersect(left, right, isAll) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
if (newChildren.isEmpty) {
s -> Nil
} else {
assert(newChildren.length == 2)
val attrMapping = left.output.zip(newChildren.head.output)
Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping
}

case s: Union if s.childrenResolved &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
s.makeCopy(Array(newChildren))
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
if (newChildren.isEmpty) {
s -> Nil
} else {
val attrMapping = s.children.head.output.zip(newChildren.head.output)
s.copy(children = newChildren) -> attrMapping
}
}
}

/** Build new children with the widest types for each attribute among all the children */
Expand All @@ -360,8 +377,7 @@ object TypeCoercion {
// Add an extra Project if the targetTypes are different from the original types.
children.map(widenTypes(_, targetTypes))
} else {
// Unable to find a target type to widen, then just return the original set.
children
Nil
}
}

Expand All @@ -387,7 +403,8 @@ object TypeCoercion {
/** Given a plan, add an extra project on top to widen some columns' data types. */
private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
val casted = plan.output.zip(targetTypes).map {
case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
case (e, dt) if e.dataType != dt =>
Alias(Cast(e, dt, Some(SQLConf.get.sessionLocalTimeZone)), e.name)()
case (e, _) => e
}
Project(casted, plan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTag}
Expand Down Expand Up @@ -168,6 +170,89 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}.toSeq
}

/**
* A variant of `transformUp`, which takes care of the case that the rule replaces a plan node
* with a new one that has different output expr IDs, by updating the attribute references in
* the parent nodes accordingly.
*
* @param rule the function to transform plan nodes, and return new nodes with attributes mapping
* from old attributes to new attributes. The attribute mapping will be used to
* rewrite attribute references in the parent nodes.
* @param skipCond a boolean condition to indicate if we can skip transforming a plan node to save
* time.
*/
def transformUpWithNewOutput(
rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])],
skipCond: PlanType => Boolean = _ => false): PlanType = {
def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = {
if (skipCond(plan)) {
plan -> Nil
} else {
val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
var newPlan = plan.mapChildren { child =>
val (newChild, childAttrMapping) = rewrite(child)
attrMapping ++= childAttrMapping
newChild
}

val attrMappingForCurrentPlan = attrMapping.filter {
// The `attrMappingForCurrentPlan` is used to replace the attributes of the
// current `plan`, so the `oldAttr` must be part of `plan.references`.
case (oldAttr, _) => plan.references.contains(oldAttr)
}

val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil)
}
newPlan = planAfterRule

if (attrMappingForCurrentPlan.nonEmpty) {
assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
.exists(_._2.map(_._2.exprId).distinct.length > 1),
"Found duplicate rewrite attributes")

val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
// Using attrMapping from the children plans to rewrite their parent node.
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
newPlan = newPlan.transformExpressions {
case a: AttributeReference =>
updateAttr(a, attributeRewrites)
case pe: PlanExpression[PlanType] =>
pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attributeRewrites))
}
}

attrMapping ++= newAttrMapping.filter {
case (a1, a2) => a1.exprId != a2.exprId
}
newPlan -> attrMapping
}
}
rewrite(this)._1
}

private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
val exprId = attrMap.getOrElse(attr, attr).exprId
attr.withExprId(exprId)
}

/**
* The outer plan may have old references and the function below updates the
* outer references to refer to the new attributes.
*/
private def updateOuterReferencesInSubquery(
plan: PlanType,
attrMap: AttributeMap[Attribute]): PlanType = {
plan.transformDown { case currentFragment =>
currentFragment.transformExpressions {
case OuterReference(a: AttributeReference) =>
OuterReference(updateAttr(a, attrMap))
case pe: PlanExpression[PlanType] =>
pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap))
}
}
}

lazy val schema: StructType = StructType.fromAttributes(output)

/** Returns the output schema in the tree format. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -120,6 +120,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
}
}

/**
* A variant of `transformUpWithNewOutput`, which skips touching already analyzed plan.
*/
def resolveOperatorsUpWithNewOutput(
rule: PartialFunction[LogicalPlan, (LogicalPlan, Seq[(Attribute, Attribute)])])
: LogicalPlan = {
if (!analyzed) {
transformUpWithNewOutput(rule, skipCond = _.analyzed)
} else {
self
}
}

/**
* Recursively transforms the expressions of a tree, skipping nodes that have already
* been analyzed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ import java.sql.Timestamp

import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

class TypeCoercionSuite extends AnalysisTest {
import TypeCoercionSuite._
Expand Down Expand Up @@ -1417,6 +1416,21 @@ class TypeCoercionSuite extends AnalysisTest {
}
}

test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
val p1 = t1.select(t1.output.head).as("p1")
val p2 = t2.select(t2.output.head).as("p2")
val union = p1.union(p2)
val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v"))
assert(wp1.isInstanceOf[Project])
// The attribute `p1.output.head` should be replaced in the root `Project`.
assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty))
val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
assert(wp2.isInstanceOf[Aggregate])
assert(wp2.missingInput.isEmpty)
}

/**
* There are rules that need to not fire before child expressions get resolved.
* We use this test to make sure those rules do not fire early.
Expand Down
19 changes: 19 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/except.sql
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,22 @@ FROM t1
WHERE t1.v >= (SELECT min(t2.v)
FROM t2
WHERE t2.k = t1.k);

-- SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes
CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
SELECT t.v FROM (
SELECT v FROM t3
EXCEPT
SELECT v + v AS v FROM t3
) t;

SELECT SUM(t.v) FROM (
SELECT v FROM t3
EXCEPT
SELECT v + v AS v FROM t3
) t;

-- Clean-up
DROP VIEW IF EXISTS t1;
DROP VIEW IF EXISTS t2;
DROP VIEW IF EXISTS t3;
Loading

0 comments on commit 3f20f14

Please sign in to comment.