Skip to content

Commit

Permalink
[SPARK-37915][SQL] Combine unions if there is a project between them
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This pr makes `CombineUnions` combine unions if there is a project between them. For example:
```scala
spark.range(1).selectExpr("CAST(id AS decimal(18, 1)) AS id").write.saveAsTable("t1")
spark.range(2).selectExpr("CAST(id AS decimal(18, 2)) AS id").write.saveAsTable("t2")
spark.range(3).selectExpr("CAST(id AS decimal(18, 3)) AS id").write.saveAsTable("t3")
spark.range(4).selectExpr("CAST(id AS decimal(18, 4)) AS id").write.saveAsTable("t4")
spark.range(5).selectExpr("CAST(id AS decimal(18, 5)) AS id").write.saveAsTable("t5")

spark.sql("SELECT id FROM t1 UNION SELECT id FROM t2 UNION SELECT id FROM t3 UNION SELECT id FROM t4 UNION SELECT id FROM t5").explain(true)
```

Before this pr:
```
== Optimized Logical Plan ==
Aggregate [id#36], [id#36]
+- Union false, false
   :- Aggregate [id#34], [cast(id#34 as decimal(22,5)) AS id#36]
   :  +- Union false, false
   :     :- Aggregate [id#32], [cast(id#32 as decimal(21,4)) AS id#34]
   :     :  +- Union false, false
   :     :     :- Aggregate [id#30], [cast(id#30 as decimal(20,3)) AS id#32]
   :     :     :  +- Union false, false
   :     :     :     :- Project [cast(id#25 as decimal(19,2)) AS id#30]
   :     :     :     :  +- Relation default.t1[id#25] parquet
   :     :     :     +- Project [cast(id#26 as decimal(19,2)) AS id#31]
   :     :     :        +- Relation default.t2[id#26] parquet
   :     :     +- Project [cast(id#27 as decimal(20,3)) AS id#33]
   :     :        +- Relation default.t3[id#27] parquet
   :     +- Project [cast(id#28 as decimal(21,4)) AS id#35]
   :        +- Relation default.t4[id#28] parquet
   +- Project [cast(id#29 as decimal(22,5)) AS id#37]
      +- Relation default.t5[id#29] parquet
```

After this pr:
```
== Optimized Logical Plan ==
Aggregate [id#36], [id#36]
+- Union false, false
   :- Project [cast(id#25 as decimal(22,5)) AS id#36]
   :  +- Relation default.t1[id#25] parquet
   :- Project [cast(id#26 as decimal(22,5)) AS id#46]
   :  +- Relation default.t2[id#26] parquet
   :- Project [cast(id#27 as decimal(22,5)) AS id#45]
   :  +- Relation default.t3[id#27] parquet
   :- Project [cast(id#28 as decimal(22,5)) AS id#44]
   :  +- Relation default.t4[id#28] parquet
   +- Project [cast(id#29 as decimal(22,5)) AS id#37]
      +- Relation default.t5[id#29] parquet
```

### Why are the changes needed?

Improve query performance by reduce shuffles.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #35214 from wangyum/SPARK-37915.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
wangyum authored and cloud-fan committed Jan 25, 2022
1 parent 18f9e7e commit ac2b0df
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -764,22 +764,22 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
result.asInstanceOf[A]
}

def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
val newFirstChild = Project(projectList, u.children.head)
val newOtherChildren = u.children.tail.map { child =>
val rewrites = buildRewrites(u.children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
}
newFirstChild +: newOtherChildren
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAllPatterns(UNION, PROJECT)) {

// Push down deterministic projection through UNION ALL
case p @ Project(projectList, u: Union) =>
assert(u.children.nonEmpty)
if (projectList.forall(_.deterministic)) {
val newFirstChild = Project(projectList, u.children.head)
val newOtherChildren = u.children.tail.map { child =>
val rewrites = buildRewrites(u.children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
}
u.copy(children = newFirstChild +: newOtherChildren)
} else {
p
}
case Project(projectList, u: Union)
if projectList.forall(_.deterministic) && u.children.nonEmpty =>
u.copy(children = pushProjectionThroughUnion(projectList, u))
}
}

Expand Down Expand Up @@ -1006,7 +1006,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
}.isEmpty)
}

private def buildCleanedProjectList(
def buildCleanedProjectList(
upper: Seq[NamedExpression],
lower: Seq[NamedExpression]): Seq[NamedExpression] = {
val aliases = getAliasMap(lower)
Expand Down Expand Up @@ -1300,6 +1300,9 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
object CombineUnions extends Rule[LogicalPlan] {
import CollapseProject.{buildCleanedProjectList, canCollapseExpressions}
import PushProjectionThroughUnion.pushProjectionThroughUnion

def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
_.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
case u: Union => flattenUnion(u, false)
Expand All @@ -1321,6 +1324,10 @@ object CombineUnions extends Rule[LogicalPlan] {
// rules (by position and by name) could cause incorrect results.
while (stack.nonEmpty) {
stack.pop() match {
case p1 @ Project(_, p2: Project)
if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline = false) =>
val newProjectList = buildCleanedProjectList(p1.projectList, p2.projectList)
stack.pushAll(Seq(p2.copy(projectList = newProjectList)))
case Distinct(Union(children, byName, allowMissingCol))
if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
Expand All @@ -1332,6 +1339,20 @@ object CombineUnions extends Rule[LogicalPlan] {
case Union(children, byName, allowMissingCol)
if byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
// Push down projection through Union and then push pushed plan to Stack if
// there is a Project.
case Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol)))
if projectList.forall(_.deterministic) && children.nonEmpty &&
flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union))
if projectList.forall(_.deterministic) && flattenDistinct && u.byName == topByName &&
u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case Project(projectList, u @ Union(children, byName, allowMissingCol))
if projectList.forall(_.deterministic) && children.nonEmpty &&
byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case child =>
flattened += child
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanO
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.{BooleanType, DecimalType}

class SetOperationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
Expand Down Expand Up @@ -328,4 +328,66 @@ class SetOperationSuite extends PlanTest {
Union(testRelation :: testRelation :: testRelation :: testRelation :: Nil, true, false)
comparePlans(unionOptimized2, unionCorrectAnswer2, false)
}

test("SPARK-37915: combine unions if there is a project between them") {
val relation1 = LocalRelation('a.decimal(18, 1), 'b.int)
val relation2 = LocalRelation('a.decimal(18, 2), 'b.int)
val relation3 = LocalRelation('a.decimal(18, 3), 'b.int)
val relation4 = LocalRelation('a.decimal(18, 4), 'b.int)
val relation5 = LocalRelation('a.decimal(18, 5), 'b.int)

val optimizedRelation1 = relation1.select('a.cast(DecimalType(19, 2)).cast(DecimalType(20, 3))
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
val optimizedRelation2 = relation2.select('a.cast(DecimalType(19, 2)).cast(DecimalType(20, 3))
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
val optimizedRelation3 = relation3.select('a.cast(DecimalType(20, 3))
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
val optimizedRelation4 = relation4
.select('a.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
val optimizedRelation5 = relation5.select('a.cast(DecimalType(22, 5)).as("a"), 'b)

// SQL UNION ALL
comparePlans(
Optimize.execute(relation1.union(relation2)
.union(relation3).union(relation4).union(relation5).analyze),
Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5)).analyze)

// SQL UNION
comparePlans(
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
.union(relation3)).union(relation4)).union(relation5)).analyze),
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).analyze)

// Deduplicate
comparePlans(
Optimize.execute(relation1.union(relation2).deduplicate('a, 'b).union(relation3)
.deduplicate('a, 'b).union(relation4).deduplicate('a, 'b).union(relation5)
.deduplicate('a, 'b).analyze),
Deduplicate(
Seq('a, 'b),
Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).analyze)

// Other cases
comparePlans(
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
.union(relation3)).union(relation4)).union(relation5)).select('a % 2).analyze),
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).select('a % 2).analyze)

comparePlans(
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
.union(relation3)).union(relation4)).union(relation5)).select('a + 'b).analyze),
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).select('a + 'b).analyze)

comparePlans(
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
.union(relation3)).union(relation4)).union(relation5)).select('a).analyze),
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).select('a).analyze)

}
}

0 comments on commit ac2b0df

Please sign in to comment.