Skip to content

Commit

Permalink
[SPARK-32635][SQL] Fix foldable propagation
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR rewrites `FoldablePropagation` rule to replace attribute references in a node with foldables coming only from the node's children.

Before this PR in the case of this example (with setting`spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation`):
```scala
val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
val aub = a.union(b)
val c = aub.filter($"col1" === "2").cache()
val d = Seq("2").toDF( "col4")
val r = d.join(aub, $"col2" === $"col4").select("col4")
val l = c.select("col2")
val df = l.join(r, $"col2" === $"col4", "LeftOuter")
df.show()
```
foldable propagation happens incorrectly:
```
 Join LeftOuter, (col2#6 = col4#34)                                                              Join LeftOuter, (col2#6 = col4#34)
!:- Project [col2#6]                                                                             :- Project [1 AS col2#6]
 :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)   :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)
 :        +- Union                                                                               :        +- Union
 :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]                                    :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]
 :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))                            :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))
 :           :     +- *(1) LocalTableScan [value#1]                                              :           :     +- *(1) LocalTableScan [value#1]
 :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]                                 :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]
 :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))                          :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))
 :                 +- *(2) LocalTableScan [value#10]                                             :                 +- *(2) LocalTableScan [value#10]
 +- Project [col4#34]                                                                            +- Project [col4#34]
    +- Join Inner, (col2#6 = col4#34)                                                               +- Join Inner, (col2#6 = col4#34)
       :- Project [value#31 AS col4#34]                                                                :- Project [value#31 AS col4#34]
       :  +- LocalRelation [value#31]                                                                  :  +- LocalRelation [value#31]
       +- Project [col2#6]                                                                             +- Project [col2#6]
          +- Union false, false                                                                           +- Union false, false
             :- Project [1 AS col2#6]                                                                        :- Project [1 AS col2#6]
             :  +- LocalRelation [value#1]                                                                   :  +- LocalRelation [value#1]
             +- Project [2 AS col2#15]                                                                       +- Project [2 AS col2#15]
                +- LocalRelation [value#10]                                                                     +- LocalRelation [value#10]

```
and so the result is wrong:
```
+----+----+
|col2|col4|
+----+----+
|   1|null|
+----+----+
```

After this PR foldable propagation will not happen incorrectly and the result is correct:
```
+----+----+
|col2|col4|
+----+----+
|   2|   2|
+----+----+
```

### Why are the changes needed?
To fix a correctness issue.

### Does this PR introduce _any_ user-facing change?
Yes, fixes a correctness issue.

### How was this patch tested?
Existing and new UTs.

Closes #29771 from peter-toth/SPARK-32635-fix-foldable-propagation.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
(cherry picked from commit 4ced588)
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
  • Loading branch information
peter-toth authored and maropu committed Sep 17, 2020
1 parent 17a5195 commit ecc2f5d
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ object AttributeMap {
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}

def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
}

class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ object AttributeMap {
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}

def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
}

class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,67 +620,90 @@ object NullPropagation extends Rule[LogicalPlan] {
*/
object FoldablePropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
var foldableMap = AttributeMap(plan.flatMap {
case Project(projectList, _) => projectList.collect {
case a: Alias if a.child.foldable => (a.toAttribute, a)
}
case _ => Nil
})
val replaceFoldable: PartialFunction[Expression, Expression] = {
case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
CleanupAliases(propagateFoldables(plan)._1)
}

private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, AttributeMap[Alias]) = {
plan match {
case p: Project =>
val (newChild, foldableMap) = propagateFoldables(p.child)
val newProject =
replaceFoldable(p.withNewChildren(Seq(newChild)).asInstanceOf[Project], foldableMap)
val newFoldableMap = AttributeMap(newProject.projectList.collect {
case a: Alias if a.child.foldable => (a.toAttribute, a)
})
(newProject, newFoldableMap)

// We can not replace the attributes in `Expand.output`. If there are other non-leaf
// operators that have the `output` field, we should put them here too.
case e: Expand =>
val (newChild, foldableMap) = propagateFoldables(e.child)
val expandWithNewChildren = e.withNewChildren(Seq(newChild)).asInstanceOf[Expand]
val newExpand = if (foldableMap.isEmpty) {
expandWithNewChildren
} else {
val newProjections = expandWithNewChildren.projections.map(_.map(_.transform {
case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
}))
if (newProjections == expandWithNewChildren.projections) {
expandWithNewChildren
} else {
expandWithNewChildren.copy(projections = newProjections)
}
}
(newExpand, foldableMap)

case u: UnaryNode if canPropagateFoldables(u) =>
val (newChild, foldableMap) = propagateFoldables(u.child)
val newU = replaceFoldable(u.withNewChildren(Seq(newChild)), foldableMap)
(newU, foldableMap)

// Join derives the output attributes from its child while they are actually not the
// same attributes. For example, the output of outer join is not always picked from its
// children, but can also be null. We should exclude these miss-derived attributes when
// propagating the foldable expressions.
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
// of outer join.
case j: Join =>
val (newChildren, foldableMaps) = j.children.map(propagateFoldables).unzip
val foldableMap = AttributeMap(
foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ _.baseMap.values).toSeq)
val newJoin =
replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap)
val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match {
case _: InnerLike | LeftExistence(_) => Nil
case LeftOuter => newJoin.right.output
case RightOuter => newJoin.left.output
case FullOuter => newJoin.left.output ++ newJoin.right.output
})
val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
case (attr, _) => missDerivedAttrsSet.contains(attr)
}.toSeq)
(newJoin, newFoldableMap)

// For other plans, they are not safe to apply foldable propagation, and they should not
// propagate foldable expressions from children.
case o =>
val newOther = o.mapChildren(propagateFoldables(_)._1)
(newOther, AttributeMap.empty)
}
}

private def replaceFoldable(plan: LogicalPlan, foldableMap: AttributeMap[Alias]): plan.type = {
if (foldableMap.isEmpty) {
plan
} else {
CleanupAliases(plan.transformUp {
// We can only propagate foldables for a subset of unary nodes.
case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) =>
u.transformExpressions(replaceFoldable)

// Join derives the output attributes from its child while they are actually not the
// same attributes. For example, the output of outer join is not always picked from its
// children, but can also be null. We should exclude these miss-derived attributes when
// propagating the foldable expressions.
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
// of outer join.
case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty =>
val newJoin = j.transformExpressions(replaceFoldable)
val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
case _: InnerLike | LeftExistence(_) => Nil
case LeftOuter => right.output
case RightOuter => left.output
case FullOuter => left.output ++ right.output
})
foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
case (attr, _) => missDerivedAttrsSet.contains(attr)
}.toSeq)
newJoin

// We can not replace the attributes in `Expand.output`. If there are other non-leaf
// operators that have the `output` field, we should put them here too.
case expand: Expand if foldableMap.nonEmpty =>
expand.copy(projections = expand.projections.map { projection =>
projection.map(_.transform(replaceFoldable))
})

// For other plans, they are not safe to apply foldable propagation, and they should not
// propagate foldable expressions from children.
case other if foldableMap.nonEmpty =>
val childrenOutputSet = AttributeSet(other.children.flatMap(_.output))
foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
case (attr, _) => childrenOutputSet.contains(attr)
}.toSeq)
other
})
plan transformExpressions {
case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
}
}
}

/**
* Whitelist of all [[UnaryNode]]s for which allow foldable propagation.
*/
private def canPropagateFoldables(u: UnaryNode): Boolean = u match {
case _: Project => true
// Handling `Project` is moved to `propagateFoldables`.
case _: Filter => true
case _: SubqueryAlias => true
case _: Aggregate => true
Expand Down
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2455,6 +2455,18 @@ class DataFrameSuite extends QueryTest
val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
checkAnswer(df.select($"pos" > $"neg"), Row(false))
}

test("SPARK-32635: Replace references with foldables coming only from the node's children") {
val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
val aub = a.union(b)
val c = aub.filter($"col1" === "2").cache()
val d = Seq("2").toDF("col4")
val r = d.join(aub, $"col2" === $"col4").select("col4")
val l = c.select("col2")
val df = l.join(r, $"col2" === $"col4", "LeftOuter")
checkAnswer(df, Row("2", "2"))
}
}

case class GroupByKey(a: Int, b: Int)

0 comments on commit ecc2f5d

Please sign in to comment.