Skip to content

Commit

Permalink
SPARK-3462 add tests, handle compatible schema with different aliases…
Browse files Browse the repository at this point in the history
…, per marmbrus feedback
  • Loading branch information
Cody Koeninger committed Sep 11, 2014
1 parent ef47b3b commit 0788691
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,32 +47,49 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
ColumnPruning) :: Nil
}

/**
* Pushes operations to either side of a Union.
*/
object UnionPushdown extends Rule[LogicalPlan] {
def fixProject(project: Project, child: LogicalPlan): Project = {
val pl = project.projectList.map(p => child.output.find { c =>
c.name == p.name && c.qualifiers == p.qualifiers
}.getOrElse(p))
Project(pl, child)

/**
* Maps Attributes from the left side to the corresponding Attribute on the right side.
*/
def buildRewrites(union: Union): AttributeMap[Attribute] = {
assert(union.left.output.size == union.right.output.size)

AttributeMap(union.left.output.zip(union.right.output))
}

def fixFilter(filter: Filter, child: LogicalPlan): Filter = {
val cond = filter.condition.transform {
case a: AttributeReference =>
child.output.find { c =>
c.name == a.name && c.qualifiers == a.qualifiers
}.getOrElse(a)
/**
* Rewrites an expression so that it can be pushed to the right side of a Union operator.
* This method relies on the fact that the output attributes of a union are always equal
* to the left child's output.
*/
def pushToRight[A <: Expression](e: A, union: Union, rewrites: AttributeMap[Attribute]): A = {
val result = e transform {
case a: Attribute => rewrites(a)
}
Filter(cond, child)

// We must promise the compiler that we did not discard the names in the case of project
// expressions. This is safe since the only transformation is from Attribute => Attribute.
result.asInstanceOf[A]
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Push down filter into union
case f @ Filter(_, Union(left, right)) =>
Union(fixFilter(f, left), fixFilter(f, right))
case Filter(condition, u @ Union(left, right)) =>
val rewrites = buildRewrites(u)
Union(
Filter(condition, left),
Filter(pushToRight(condition, u, rewrites), right))

// Push down projection into union
case p @ Project(_, Union(left, right)) =>
Union(fixProject(p, left), fixProject(p, right))
case Project(projectList, u @ Union(left, right)) =>
val rewrites = buildRewrites(u)
Union(
Project(projectList, left),
Project(projectList.map(pushToRight(_, u, rewrites)), right))
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

/**
 * Plan-level tests for [[UnionPushdown]]: a filter or a projection placed on top of a
 * `Union` is expected to be rewritten into a `Union` of per-child filters / projections,
 * with the right-hand side expressed in terms of the right child's own attributes.
 */
class UnionPushdownSuite extends PlanTest {

  /**
   * Rule executor restricted to the two batches this suite needs: the "Subqueries" batch
   * (running `EliminateAnalysisOperators`) followed by the "Union Pushdown" batch, each
   * applied once.
   */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Subqueries", Once, EliminateAnalysisOperators),
      Batch("Union Pushdown", Once, UnionPushdown))
  }

  // Two relations with the same column types but different column names, so the tests
  // can observe attributes being rewritten for the right side of the union.
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
  val testUnion = Union(testRelation, testRelation2)

  test("union: filter to each side") {
    // A predicate on the union's first column ...
    val filteredUnion = testUnion.where('a === 1)
    val actualPlan = Optimize(filteredUnion.analyze)

    // ... should become a predicate on the first column of each child.
    val leftSide = testRelation.where('a === 1)
    val rightSide = testRelation2.where('d === 1)
    val expectedPlan = Union(leftSide, rightSide).analyze

    comparePlans(actualPlan, expectedPlan)
  }

  test("union: project to each side") {
    // Selecting the union's second column ...
    val projectedUnion = testUnion.select('b)
    val actualPlan = Optimize(projectedUnion.analyze)

    // ... should become a selection of the second column of each child.
    val leftSide = testRelation.select('b)
    val rightSide = testRelation2.select('e)
    val expectedPlan = Union(leftSide, rightSide).analyze

    comparePlans(actualPlan, expectedPlan)
  }
}

0 comments on commit 0788691

Please sign in to comment.