Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -764,22 +764,22 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
result.asInstanceOf[A]
}

def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
val newFirstChild = Project(projectList, u.children.head)
val newOtherChildren = u.children.tail.map { child =>
val rewrites = buildRewrites(u.children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
}
newFirstChild +: newOtherChildren
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAllPatterns(UNION, PROJECT)) {

// Push down deterministic projection through UNION ALL
case p @ Project(projectList, u: Union) =>
assert(u.children.nonEmpty)
if (projectList.forall(_.deterministic)) {
val newFirstChild = Project(projectList, u.children.head)
val newOtherChildren = u.children.tail.map { child =>
val rewrites = buildRewrites(u.children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
}
u.copy(children = newFirstChild +: newOtherChildren)
} else {
p
}
case Project(projectList, u: Union)
if projectList.forall(_.deterministic) && u.children.nonEmpty =>
u.copy(children = pushProjectionThroughUnion(projectList, u))
}
}

Expand Down Expand Up @@ -1006,7 +1006,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
}.isEmpty)
}

private def buildCleanedProjectList(
def buildCleanedProjectList(
upper: Seq[NamedExpression],
lower: Seq[NamedExpression]): Seq[NamedExpression] = {
val aliases = getAliasMap(lower)
Expand Down Expand Up @@ -1300,6 +1300,9 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
object CombineUnions extends Rule[LogicalPlan] {
import CollapseProject.{buildCleanedProjectList, canCollapseExpressions}
import PushProjectionThroughUnion.pushProjectionThroughUnion

def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
_.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
case u: Union => flattenUnion(u, false)
Expand All @@ -1321,6 +1324,10 @@ object CombineUnions extends Rule[LogicalPlan] {
// rules (by position and by name) could cause incorrect results.
while (stack.nonEmpty) {
stack.pop() match {
case p1 @ Project(_, p2: Project)
if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline = false) =>
val newProjectList = buildCleanedProjectList(p1.projectList, p2.projectList)
stack.pushAll(Seq(p2.copy(projectList = newProjectList)))
case Distinct(Union(children, byName, allowMissingCol))
if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
Expand All @@ -1332,6 +1339,20 @@ object CombineUnions extends Rule[LogicalPlan] {
case Union(children, byName, allowMissingCol)
if byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
// Push down projection through Union and then push pushed plan to Stack if
// there is a Project.
case Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol)))
if projectList.forall(_.deterministic) && children.nonEmpty &&
flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union))
if projectList.forall(_.deterministic) && flattenDistinct && u.byName == topByName &&
u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case Project(projectList, u @ Union(children, byName, allowMissingCol))
if projectList.forall(_.deterministic) && children.nonEmpty &&
byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case child =>
flattened += child
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanO
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.{BooleanType, DecimalType}

class SetOperationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
Expand Down Expand Up @@ -328,4 +328,66 @@ class SetOperationSuite extends PlanTest {
Union(testRelation :: testRelation :: testRelation :: testRelation :: Nil, true, false)
comparePlans(unionOptimized2, unionCorrectAnswer2, false)
}

test("SPARK-37915: combine unions if there is a project between them") {
val relation1 = LocalRelation('a.decimal(18, 1), 'b.int)
val relation2 = LocalRelation('a.decimal(18, 2), 'b.int)
val relation3 = LocalRelation('a.decimal(18, 3), 'b.int)
val relation4 = LocalRelation('a.decimal(18, 4), 'b.int)
val relation5 = LocalRelation('a.decimal(18, 5), 'b.int)

val optimizedRelation1 = relation1.select('a.cast(DecimalType(19, 2)).cast(DecimalType(20, 3))
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
val optimizedRelation2 = relation2.select('a.cast(DecimalType(19, 2)).cast(DecimalType(20, 3))
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
val optimizedRelation3 = relation3.select('a.cast(DecimalType(20, 3))
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
val optimizedRelation4 = relation4
.select('a.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
val optimizedRelation5 = relation5.select('a.cast(DecimalType(22, 5)).as("a"), 'b)

// SQL UNION ALL
comparePlans(
Optimize.execute(relation1.union(relation2)
.union(relation3).union(relation4).union(relation5).analyze),
Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5)).analyze)

// SQL UNION
comparePlans(
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
.union(relation3)).union(relation4)).union(relation5)).analyze),
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).analyze)

// Deduplicate
comparePlans(
Optimize.execute(relation1.union(relation2).deduplicate('a, 'b).union(relation3)
.deduplicate('a, 'b).union(relation4).deduplicate('a, 'b).union(relation5)
.deduplicate('a, 'b).analyze),
Deduplicate(
Seq('a, 'b),
Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).analyze)

// Other cases
comparePlans(
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
.union(relation3)).union(relation4)).union(relation5)).select('a % 2).analyze),
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).select('a % 2).analyze)

comparePlans(
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
.union(relation3)).union(relation4)).union(relation5)).select('a + 'b).analyze),
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).select('a + 'b).analyze)

comparePlans(
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
.union(relation3)).union(relation4)).union(relation5)).select('a).analyze),
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
optimizedRelation4, optimizedRelation5))).select('a).analyze)

}
}