diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5daf86d817586..32044989044a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -117,8 +117,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys( + LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys( + RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joins.HashOuterJoin( + joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala new file mode 100644 index 0000000000000..5da04c78744d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.ThreadUtils + +import scala.collection.JavaConversions._ +import scala.concurrent._ +import scala.concurrent.duration._ + +/** + * :: DeveloperApi :: + * Performs a outer hash join for two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. 
+ */ +@DeveloperApi +case class BroadcastHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + val timeout = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + override def requiredChildDistribution: Seq[Distribution] = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + private[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + private[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + @transient + private val broadcastFuture = future { + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() + // buildHashTable uses code-generated rows as keys, which are not serializable + val hashed = + buildHashTable(input.iterator, new InterpretedProjection(buildKeys, buildPlan.output)) + sparkContext.broadcast(hashed) + }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) + + override def doExecute(): RDD[InternalRow] = { + val broadcastRelation = Await.result(broadcastFuture, timeout) + + streamedPlan.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + + joinType match { + case LeftOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + } + } +} + +object BroadcastHashOuterJoin { + + private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index e41538ec1fc1a..886b5fa0c5103 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -19,32 +19,25 @@ package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} -import org.apache.spark.rdd.RDD - -import scala.collection.JavaConversions._ - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import 
org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ @DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { +trait HashOuterJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val joinType: JoinType + val condition: Option[Expression] + val left: SparkPlan + val right: SparkPlan + +override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -52,9 +45,6 @@ case class HashOuterJoin( throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } - override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = { joinType match { case LeftOuter => @@ -68,8 +58,8 @@ case class HashOuterJoin( } } - @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow] + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) + @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @@ -80,7 +70,7 @@ case class HashOuterJoin( // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. 
- private[this] def leftOuterIterator( + protected[this] def leftOuterIterator( key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { @@ -89,7 +79,7 @@ case class HashOuterJoin( val temp = rightIter.collect { case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withRight(rightNullRow).copy :: Nil } else { temp @@ -101,18 +91,17 @@ case class HashOuterJoin( ret.iterator } - private[this] def rightOuterIterator( + protected[this] def rightOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = leftIter.collect { case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy + joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withLeft(leftNullRow).copy :: Nil } else { temp @@ -124,10 +113,9 @@ case class HashOuterJoin( ret.iterator } - private[this] def fullOuterIterator( + protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. @@ -171,7 +159,7 @@ case class HashOuterJoin( } } - private[this] def buildHashTable( + protected[this] def buildHashTable( iter: Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() @@ -190,43 +178,4 @@ case class HashOuterJoin( hashTable } - - protected override def doExecute(): RDD[InternalRow] = { - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) 
- - joinType match { - case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) - leftIter.flatMap( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) - }) - - case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) - rightIter.flatMap ( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) - }) - - case FullOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) - } - - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala new file mode 100644 index 0000000000000..cfc9c14aaa363 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +import scala.collection.JavaConversions._ + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. 
+ */ +@DeveloperApi +case class ShuffledHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val joinedRow = new JoinedRow() + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + joinType match { + case LeftOuter => + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val keyGenerator = newProjection(leftKeys, left.output) + leftIter.flatMap( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val keyGenerator = newProjection(rightKeys, right.output) + rightIter.flatMap ( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case FullOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST), + joinedRow) + } + + case x => + throw new IllegalArgumentException( + s"ShuffledHashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 20390a5544304..8953889d1fae9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -45,9 +45,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j - case j: HashOuterJoin => j + case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j + case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j @@ -81,12 +82,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[HashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + 
classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -133,6 +135,34 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.sql("UNCACHE TABLE testData") } + test("broadcasted hash outer join operator selection") { + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + } + + ctx.sql("UNCACHE TABLE testData") + } + test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 13f3be8ca28d6..108b1122f7bff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -54,6 +54,37 @@ class SparkPlanTest extends SparkFunSuite { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[Row]): Unit = { + checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + checkAnswer(left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
+ */ + protected def checkAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -72,11 +103,41 @@ class SparkPlanTest extends SparkFunSuite { planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[A]): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) - SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } + checkAnswer(input, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(left, right, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(input, planFunction, expectedRows) } + } /** @@ -92,27 +153,25 @@ object SparkPlanTest { * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ def checkAnswer( - input: DataFrame, - planFunction: SparkPlan => SparkPlan, + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row]): Option[String] = { - val outputPlan = planFunction(input.queryExecution.sparkPlan) + val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = outputPlan transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { - case (a, i) => - (a.name, BoundReference(i, a.dataType, a.nullable)) - }.toMap - - plan.transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } + val resolvedPlan = TestSQLContext.prepareForExecution.execute( + outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + ) def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. 
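The Seq[DataFrame] overloads added above are what allow a test to hand an arbitrary number of child plans to the operator under test; the two-DataFrame overloads simply delegate to them. As a minimal sketch of the multi-input form, here is how an operator that takes a Seq of children could be exercised. The Union operator, the implicits import, and the sample data are illustrative assumptions for this sketch, not part of the patch:

    import org.apache.spark.sql.execution.{SparkPlan, Union}
    import org.apache.spark.sql.test.TestSQLContext.implicits._

    class UnionExampleSuite extends SparkPlanTest {
      // Two small frames with the same schema; Union concatenates its children's rows.
      val a = Seq((1, "a"), (2, "b")).toDF("id", "name")
      val b = Seq((3, "c")).toDF("id", "name")

      test("union of two children") {
        checkAnswer(a :: b :: Nil, (plans: Seq[SparkPlan]) => Union(plans),
          Seq((1, "a"), (2, "b"), (3, "c")))
      }
    }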
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
new file mode 100644
index 0000000000000..5707d2fb300ae
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+
+class OuterJoinSuite extends SparkPlanTest {
+
+  val left = Seq(
+    (1, 2.0),
+    (2, 1.0),
+    (3, 3.0)
+  ).toDF("a", "b")
+
+  val right = Seq(
+    (2, 3.0),
+    (3, 2.0),
+    (4, 1.0)
+  ).toDF("c", "d")
+
+  val leftKeys: List[Expression] = 'a :: Nil
+  val rightKeys: List[Expression] = 'c :: Nil
+  val condition = Some(LessThan('b, 'd))
+
+  test("shuffled hash outer join") {
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
+      Seq(
+        (1, 2.0, null, null),
+        (2, 1.0, 2, 3.0),
+        (3, 3.0, null, null)
+      ))
+
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
+      Seq(
+        (2, 1.0, 2, 3.0),
+        (null, null, 3, 2.0),
+        (null, null, 4, 1.0)
+      ))
+
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
+      Seq(
+        (1, 2.0, null, null),
+        (2, 1.0, 2, 3.0),
+        (3, 3.0, null, null),
+        (null, null, 3, 2.0),
+        (null, null, 4, 1.0)
+      ))
+  }
+
+  test("broadcast hash outer join") {
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
+      Seq(
+        (1, 2.0, null, null),
+        (2, 1.0, 2, 3.0),
+        (3, 3.0, null, null)
+      ))
+
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
+      Seq(
+        (2, 1.0, 2, 3.0),
+        (null, null, 3, 2.0),
+        (null, null, 4, 1.0)
+      ))
+  }
+}
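Whether the planner picks BroadcastHashOuterJoin or ShuffledHashOuterJoin is decided by the new cases in SparkStrategies: for a LEFT OUTER join the right child is the broadcast candidate, for a RIGHT OUTER join the left child is, and CanBroadcast only matches when that side's estimated size is below spark.sql.autoBroadcastJoinThreshold; FULL OUTER joins always fall through to ShuffledHashOuterJoin. A rough sketch, mirroring what assertJoin does in JoinSuite above and reusing its cached testData tables, of how the chosen operator can be inspected directly:

    import org.apache.spark.sql.execution.joins.{BroadcastHashOuterJoin, ShuffledHashOuterJoin}
    import org.apache.spark.sql.test.TestSQLContext

    val ctx = TestSQLContext
    ctx.sql("CACHE TABLE testData")  // cached relations carry small, known size statistics

    // testData is the left (build/broadcast) side of this RIGHT OUTER join, so CanBroadcast
    // should match and the physical plan should contain a BroadcastHashOuterJoin node.
    val physical = ctx.sql("SELECT * FROM testData RIGHT JOIN testData2 ON key = a")
      .queryExecution.sparkPlan
    val outerJoins = physical.collect {
      case j: BroadcastHashOuterJoin => j
      case j: ShuffledHashOuterJoin => j
    }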
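One operational note on the broadcast path: BroadcastHashOuterJoin waits on the asynchronous broadcast future for sqlContext.conf.broadcastTimeout seconds and maps a negative value to Duration.Inf. If collecting and broadcasting the build side is slow, the wait can presumably be raised through the corresponding SQL conf; the key name below is assumed to be the standard one:

    // Sketch only: allow up to ten minutes for the build side to be collected and broadcast,
    // or set a negative value to wait indefinitely.
    ctx.sql("SET spark.sql.broadcastTimeout=600")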