From b5a4efa1f5f578e4b5001bed5e3150fa092fe3d1 Mon Sep 17 00:00:00 2001 From: kai Date: Wed, 1 Jul 2015 10:25:18 -0700 Subject: [PATCH 1/4] (1) Add broadcast hash outer join, (2) Fix SparkPlanTest --- .../spark/sql/execution/SparkStrategies.scala | 15 ++- .../joins/BroadcastHashOuterJoin.scala | 120 ++++++++++++++++++ .../sql/execution/joins/HashOuterJoin.scala | 95 ++++---------- .../sql/execution/joins/HashedRelation.scala | 7 + .../joins/ShuffledHashOuterJoin.scala | 85 +++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 39 +++++- .../spark/sql/execution/SparkPlanTest.scala | 103 +++++++++++---- .../sql/execution/joins/OuterJoinSuite.scala | 88 +++++++++++++ 8 files changed, 450 insertions(+), 102 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5daf86d817586..1386531353005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -118,8 +118,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joins.HashOuterJoin( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + joinType match { + case LeftOuter if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + case RightOuter if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + case _ => + joins.ShuffledHashOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + } case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala new file mode 100644 index 0000000000000..372cd22d80e27 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.ThreadUtils + +import scala.concurrent._ +import scala.concurrent.duration._ + +/** + * :: DeveloperApi :: + * Performs a outer hash join for two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + val timeout = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + private[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + private[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + @transient + private val broadcastFuture = future { + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() + // buildHashTable uses code-generated rows as keys, which are not serializable + val hashed = new GeneralHashedRelation( + buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))) + sparkContext.broadcast(hashed) + }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) + + override def doExecute(): RDD[InternalRow] = { + val broadcastRelation = Await.result(broadcastFuture, timeout) + + streamedPlan.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + + joinType match { + case LeftOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + 
rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + } + } +} + +object BroadcastHashOuterJoin { + + private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index e41538ec1fc1a..886b5fa0c5103 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -19,32 +19,25 @@ package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} -import org.apache.spark.rdd.RDD - -import scala.collection.JavaConversions._ - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ @DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { +trait HashOuterJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val joinType: JoinType + val condition: Option[Expression] + val left: SparkPlan + val right: SparkPlan + +override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -52,9 +45,6 @@ case class HashOuterJoin( throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } - override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = { joinType match { case LeftOuter => @@ -68,8 +58,8 @@ case class HashOuterJoin( } } - @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow] + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) + @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @@ -80,7 +70,7 @@ case class HashOuterJoin( // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // 
iterator for performance purpose. - private[this] def leftOuterIterator( + protected[this] def leftOuterIterator( key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { @@ -89,7 +79,7 @@ case class HashOuterJoin( val temp = rightIter.collect { case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withRight(rightNullRow).copy :: Nil } else { temp @@ -101,18 +91,17 @@ case class HashOuterJoin( ret.iterator } - private[this] def rightOuterIterator( + protected[this] def rightOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = leftIter.collect { case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy + joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withLeft(leftNullRow).copy :: Nil } else { temp @@ -124,10 +113,9 @@ case class HashOuterJoin( ret.iterator } - private[this] def fullOuterIterator( + protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. @@ -171,7 +159,7 @@ case class HashOuterJoin( } } - private[this] def buildHashTable( + protected[this] def buildHashTable( iter: Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() @@ -190,43 +178,4 @@ case class HashOuterJoin( hashTable } - - protected override def doExecute(): RDD[InternalRow] = { - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) 
- - joinType match { - case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) - leftIter.flatMap( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) - }) - - case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) - rightIter.flatMap ( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) - }) - - case FullOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) - } - - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index e18c817975134..9b114dc8c507b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -32,6 +32,13 @@ import org.apache.spark.util.collection.CompactBuffer private[joins] sealed trait HashedRelation { def get(key: InternalRow): CompactBuffer[InternalRow] + def getOrElse( + key: InternalRow, + default: CompactBuffer[InternalRow]): CompactBuffer[InternalRow] = { + val v = get(key) + if (v eq null) default else v + } + // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala new file mode 100644 index 0000000000000..670666ca0512f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +import scala.collection.JavaConversions._ + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class ShuffledHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + override def requiredChildDistribution: Seq[ClusteredDistribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val joinedRow = new JoinedRow() + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + joinType match { + case LeftOuter => + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val keyGenerator = newProjection(leftKeys, left.output) + leftIter.flatMap( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val keyGenerator = newProjection(rightKeys, right.output) + rightIter.flatMap ( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case FullOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST), + joinedRow) + } + + case x => + throw new IllegalArgumentException( + s"ShuffledHashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 20390a5544304..090c05f87ca61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -45,9 +45,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j - case j: HashOuterJoin => j + case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j + case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j @@ -81,12 +82,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * FROM 
testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[HashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -133,6 +134,34 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.sql("UNCACHE TABLE testData") } + test("broadcasted hash outer join operator selection") { + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + } + + ctx.sql("UNCACHE TABLE testData") + } + test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 13f3be8ca28d6..2de83109e40a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -54,6 +54,37 @@ class SparkPlanTest extends SparkFunSuite { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[Row]): Unit = { + checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
+ */ + protected def checkAnswer( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + checkAnswer(left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -68,15 +99,45 @@ class SparkPlanTest extends SparkFunSuite { * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. */ protected def checkAnswer[A <: Product : TypeTag]( - input: DataFrame, - planFunction: SparkPlan => SparkPlan, + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(input, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(left, right, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[A]): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) - SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } + checkAnswer(input, planFunction, expectedRows) } + } /** @@ -92,27 +153,25 @@ object SparkPlanTest { * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ def checkAnswer( - input: DataFrame, - planFunction: SparkPlan => SparkPlan, + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row]): Option[String] = { - val outputPlan = planFunction(input.queryExecution.sparkPlan) + val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = outputPlan transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { - case (a, i) => - (a.name, BoundReference(i, a.dataType, a.nullable)) - }.toMap - - plan.transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } + val resolvedPlan = TestSQLContext.prepareForExecution.execute( + outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + ) def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala new file mode 100644 index 0000000000000..5707d2fb300ae --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} + +class OuterJoinSuite extends SparkPlanTest { + + val left = Seq( + (1, 2.0), + (2, 1.0), + (3, 3.0) + ).toDF("a", "b") + + val right = Seq( + (2, 3.0), + (3, 2.0), + (4, 1.0) + ).toDF("c", "d") + + val leftKeys: List[Expression] = 'a :: Nil + val rightKeys: List[Expression] = 'c :: Nil + val condition = Some(LessThan('b, 'd)) + + test("shuffled hash outer join") { + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + } + + test("broadcast hash outer join") { + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + } +} From dc5127e4eb00cfeafbb3b887001532abcb4b5920 Mon Sep 17 00:00:00 2001 From: kai Date: Wed, 1 Jul 2015 10:43:05 -0700 Subject: [PATCH 2/4] code style fixes --- .../execution/joins/BroadcastHashOuterJoin.scala | 4 ++-- .../execution/joins/ShuffledHashOuterJoin.scala | 4 ++-- .../scala/org/apache/spark/sql/JoinSuite.scala | 3 ++- .../apache/spark/sql/execution/SparkPlanTest.scala | 14 +++++++------- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 372cd22d80e27..0b01027e6fc47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils @@ -53,7 +53,7 @@ case class BroadcastHashOuterJoin( } } - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: 
UnspecifiedDistribution :: Nil private[this] lazy val (buildPlan, streamedPlan) = joinType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 670666ca0512f..cfc9c14aaa363 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution} import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -40,7 +40,7 @@ case class ShuffledHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { - override def requiredChildDistribution: Seq[ClusteredDistribution] = + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 090c05f87ca61..8953889d1fae9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -87,7 +87,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 2de83109e40a6..108b1122f7bff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -99,9 +99,9 @@ class SparkPlanTest extends SparkFunSuite { * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. */ protected def checkAnswer[A <: Product : TypeTag]( - input: DataFrame, - planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[A]): Unit = { + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[A]): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) checkAnswer(input, planFunction, expectedRows) } @@ -115,10 +115,10 @@ class SparkPlanTest extends SparkFunSuite { * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. 
*/ protected def checkAnswer[A <: Product : TypeTag]( - left: DataFrame, - right: DataFrame, - planFunction: (SparkPlan, SparkPlan) => SparkPlan, - expectedAnswer: Seq[A]): Unit = { + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[A]): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) checkAnswer(left, right, planFunction, expectedRows) } From 14e4bf8184f27a262e0ac0355be090fa0c8c5a3c Mon Sep 17 00:00:00 2001 From: kai Date: Thu, 2 Jul 2015 00:12:21 -0700 Subject: [PATCH 3/4] Use CanBroadcast in broadcast outer join planning --- .../spark/sql/execution/SparkStrategies.scala | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1386531353005..32044989044a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -117,20 +117,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys( + LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys( + RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joinType match { - case LeftOuter if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - case RightOuter if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - case _ => - joins.ShuffledHashOuterJoin( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - } + joins.ShuffledHashOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil } From 3742359b674c355da233e4cbeef49b65faa792b6 Mon Sep 17 00:00:00 2001 From: kai Date: Fri, 3 Jul 2015 17:00:07 -0700 Subject: [PATCH 4/4] Fix not-serializable exception for code-generated keys in broadcasted relations --- .../spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 5 +++-- .../apache/spark/sql/execution/joins/HashedRelation.scala | 7 ------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 0b01027e6fc47..5da04c78744d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -25,6 +25,7 @@ 
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils +import scala.collection.JavaConversions._ import scala.concurrent._ import scala.concurrent.duration._ @@ -77,8 +78,8 @@ case class BroadcastHashOuterJoin( // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() // buildHashTable uses code-generated rows as keys, which are not serializable - val hashed = new GeneralHashedRelation( - buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))) + val hashed = + buildHashTable(input.iterator, new InterpretedProjection(buildKeys, buildPlan.output)) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 9b114dc8c507b..e18c817975134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -32,13 +32,6 @@ import org.apache.spark.util.collection.CompactBuffer private[joins] sealed trait HashedRelation { def get(key: InternalRow): CompactBuffer[InternalRow] - def getOrElse( - key: InternalRow, - default: CompactBuffer[InternalRow]): CompactBuffer[InternalRow] = { - val v = get(key) - if (v eq null) default else v - } - // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
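
A minimal spark-shell sketch of how the new operator selection behaves, for illustration only: the table and column names below are made up, the snippet assumes an interactive session where `sqlContext` is in scope, and the broadcast path is only chosen when the non-outer side's estimated size is under spark.sql.autoBroadcastJoinThreshold — which in practice means that side needs concrete statistics (for example a cached table, as in the JoinSuite test added above). FULL OUTER joins always fall back to ShuffledHashOuterJoin.

    import sqlContext.implicits._

    // Hypothetical tables: a large streamed side and a small build side.
    val big   = sqlContext.sparkContext.parallelize(1 to 100000).map(i => (i, i.toString)).toDF("key", "value")
    val small = sqlContext.sparkContext.parallelize(1 to 100).map(i => (i, i * 2)).toDF("a", "b")
    big.registerTempTable("big")
    small.registerTempTable("small")
    sqlContext.sql("CACHE TABLE small")  // caching gives `small` concrete size statistics

    // LEFT OUTER JOIN: the right-hand (build) side is small and cached, so the planner
    // should select BroadcastHashOuterJoin and avoid shuffling the large streamed side.
    val leftOuter = sqlContext.sql("SELECT * FROM big LEFT OUTER JOIN small ON key = a")
    println(leftOuter.queryExecution.sparkPlan)

    // FULL OUTER JOIN: never broadcast; expect ShuffledHashOuterJoin in the printed plan.
    val fullOuter = sqlContext.sql("SELECT * FROM big FULL OUTER JOIN small ON key = a")
    println(fullOuter.queryExecution.sparkPlan)

Swapping the join direction (a RIGHT OUTER JOIN with the cached table on the left) exercises the other CanBroadcast branch introduced in SparkStrategies by patch 3.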