
(1) Add broadcast hash outer join, (2) Fix SparkPlanTest
kai committed Jul 1, 2015
1 parent d16a944 commit b5a4efa
Showing 8 changed files with 450 additions and 102 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -118,8 +118,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
joins.HashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
joinType match {
case LeftOuter if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
joins.BroadcastHashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
case RightOuter if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
joins.BroadcastHashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
case _ =>
joins.ShuffledHashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
}

case _ => Nil
}
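
For context, here is a minimal usage sketch of the new rule against the 1.4-era public API. It is hypothetical rather than part of this commit: the object name, data, and sizes are invented, and whether the right side is actually broadcast depends on its statistics.sizeInBytes estimate, per the guards above.

// Hypothetical demo: nudge the planner toward BroadcastHashOuterJoin by keeping
// the build (right) side of a LEFT OUTER join under autoBroadcastJoinThreshold.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object BroadcastOuterJoinDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // 10 MB threshold; a non-positive value disables broadcast joins entirely,
    // matching the `autoBroadcastJoinThreshold > 0` guard in the strategy.
    sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)

    val large = sc.parallelize(1 to 100000).map(i => (i, s"v$i")).toDF("k", "v")
    val small = sc.parallelize(1 to 100).map(i => (i, s"w$i")).toDF("k", "w")

    // LeftOuter broadcasts the right side when its size estimate fits the threshold.
    val joined = large.join(small, large("k") === small("k"), "left_outer")
    println(joined.queryExecution.executedPlan)
  }
}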
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala (new file)
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils

import scala.concurrent._
import scala.concurrent.duration._

/**
* :: DeveloperApi ::
 * Performs an outer hash join for two child relations. When the output RDD of this operator is
* being constructed, a Spark job is asynchronously started to calculate the values for the
* broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
* relation is not shuffled.
*/
@DeveloperApi
case class BroadcastHashOuterJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashOuterJoin {

val timeout = {
val timeoutValue = sqlContext.conf.broadcastTimeout
if (timeoutValue < 0) {
Duration.Inf
} else {
timeoutValue.seconds
}
}

override def requiredChildDistribution =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

private[this] lazy val (buildPlan, streamedPlan) = joinType match {
case RightOuter => (left, right)
case LeftOuter => (right, left)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

private[this] lazy val (buildKeys, streamedKeys) = joinType match {
case RightOuter => (leftKeys, rightKeys)
case LeftOuter => (rightKeys, leftKeys)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}

@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
// buildHashTable uses code-generated rows as keys, which are not serializable
val hashed = new GeneralHashedRelation(
buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output)))
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)

override def doExecute(): RDD[InternalRow] = {
val broadcastRelation = Await.result(broadcastFuture, timeout)

streamedPlan.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
val keyGenerator = newProjection(streamedKeys, streamedPlan.output)

joinType match {
case LeftOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
})

case RightOuter =>
streamedIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
})

case x =>
throw new IllegalArgumentException(
s"BroadcastHashOuterJoin should not take $x as the JoinType")
}
}
}
}

object BroadcastHashOuterJoin {

private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128))
}
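
The future-based handoff above is the core of the operator: the build side is collected and hashed on a dedicated thread pool as soon as the plan is constructed, and doExecute() blocks for at most spark.sql.broadcastTimeout. Below is a standalone sketch of that pattern, with plain Scala collections standing in for the Spark types.

// Standalone sketch of the eager-build / lazy-await pattern used above.
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object AsyncBuildSketch {
  // A dedicated pool, like broadcastHashOuterJoinExecutionContext: long-running
  // builds should not starve the global execution context.
  implicit val ec: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

  def main(args: Array[String]): Unit = {
    // Started eagerly at construction time, like the @transient broadcastFuture.
    val buildFuture: Future[Map[Int, Seq[Int]]] =
      Future((1 to 1000).groupBy(_ % 10).mapValues(_.toSeq).toMap)

    // The doExecute() analogue: block only when the hash table is actually needed.
    val hashed = Await.result(buildFuture, 300.seconds)
    println(hashed(3).take(5))
  }
}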
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -19,42 +19,32 @@ package org.apache.spark.sql.execution.joins

import java.util.{HashMap => JavaHashMap}

import org.apache.spark.rdd.RDD

import scala.collection.JavaConversions._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.collection.CompactBuffer

/**
* :: DeveloperApi ::
* Performs a hash based outer join for two child relations by shuffling the data using
 * the join keys. This operator requires loading the associated partitions of both sides into memory.
*/
@DeveloperApi
case class HashOuterJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = joinType match {
trait HashOuterJoin {
self: SparkPlan =>

val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val joinType: JoinType
val condition: Option[Expression]
val left: SparkPlan
val right: SparkPlan

override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case x =>
throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
}

override def requiredChildDistribution: Seq[ClusteredDistribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def output: Seq[Attribute] = {
joinType match {
case LeftOuter =>
@@ -68,8 +58,8 @@ case class HashOuterJoin(
}
}

@transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null)
@transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow]
@transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
@transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()

@transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
@@ -80,7 +70,7 @@ case class HashOuterJoin(
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterators, for performance purposes.

private[this] def leftOuterIterator(
protected[this] def leftOuterIterator(
key: InternalRow,
joinedRow: JoinedRow,
rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
@@ -89,7 +79,7 @@ case class HashOuterJoin(
val temp = rightIter.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
}
if (temp.size == 0) {
if (temp.isEmpty) {
joinedRow.withRight(rightNullRow).copy :: Nil
} else {
temp
@@ -101,18 +91,17 @@ case class HashOuterJoin(
ret.iterator
}

private[this] def rightOuterIterator(
protected[this] def rightOuterIterator(
key: InternalRow,
leftIter: Iterable[InternalRow],
joinedRow: JoinedRow): Iterator[InternalRow] = {

val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = leftIter.collect {
case l if boundCondition(joinedRow.withLeft(l)) =>
joinedRow.copy
joinedRow.copy()
}
if (temp.size == 0) {
if (temp.isEmpty) {
joinedRow.withLeft(leftNullRow).copy :: Nil
} else {
temp
@@ -124,10 +113,9 @@ case class HashOuterJoin(
ret.iterator
}

private[this] def fullOuterIterator(
protected[this] def fullOuterIterator(
key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow],
joinedRow: JoinedRow): Iterator[InternalRow] = {

if (!key.anyNull) {
// Store the positions of records in right, if one of its associated rows satisfies
// the join condition.
@@ -171,7 +159,7 @@ case class HashOuterJoin(
}
}

private[this] def buildHashTable(
protected[this] def buildHashTable(
iter: Iterator[InternalRow],
keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = {
val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]()
@@ -190,43 +178,4 @@ case class HashOuterJoin(

hashTable
}

protected override def doExecute(): RDD[InternalRow] = {
val joinedRow = new JoinedRow()
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
// TODO this probably can be replaced by external sort (sort merged join?)

joinType match {
case LeftOuter =>
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
val keyGenerator = newProjection(leftKeys, left.output)
leftIter.flatMap( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST))
})

case RightOuter =>
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
val keyGenerator = newProjection(rightKeys, right.output)
rightIter.flatMap ( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
})

case FullOuter =>
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
(leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
fullOuterIterator(key,
leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow)
}

case x =>
throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
}
}
}
}
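
The doExecute() removed above does not disappear: the strategy change at the top of this commit dispatches the shuffle path to a new joins.ShuffledHashOuterJoin, whose file is not shown in this excerpt. The following is an assumed sketch of its shape, reusing the removed body through the new trait; imports (including scala.collection.JavaConversions._ for the map lookup) and the RightOuter/FullOuter cases are elided.

// Assumed sketch (file not shown above): the shuffle-based variant keeps the old
// ClusteredDistribution requirement and the old per-partition doExecute(), now as
// a case class mixing in the HashOuterJoin trait.
case class ShuffledHashOuterJoin(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan) extends BinaryNode with HashOuterJoin {

  override def requiredChildDistribution: Seq[ClusteredDistribution] =
    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

  protected override def doExecute(): RDD[InternalRow] = {
    val joinedRow = new JoinedRow()
    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
      joinType match {
        case LeftOuter =>
          val hashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
          val keyGenerator = newProjection(leftKeys, left.output)
          leftIter.flatMap { currentRow =>
            val rowKey = keyGenerator(currentRow)
            joinedRow.withLeft(currentRow)
            leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
          }
        // RightOuter and FullOuter follow the removed doExecute() body above.
        case x =>
          throw new IllegalArgumentException(
            s"ShuffledHashOuterJoin should not take $x as the JoinType")
      }
    }
  }
}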
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -32,6 +32,13 @@ import org.apache.spark.util.collection.CompactBuffer
private[joins] sealed trait HashedRelation {
def get(key: InternalRow): CompactBuffer[InternalRow]

def getOrElse(
key: InternalRow,
default: CompactBuffer[InternalRow]): CompactBuffer[InternalRow] = {
val v = get(key)
if (v eq null) default else v
}

// This is a helper method to implement Externalizable, and is used by
// GeneralHashedRelation and UniqueKeyHashedRelation
protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
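
One note on the getOrElse helper added above: HashedRelation.get signals a missing key with null rather than None, so the new method lets BroadcastHashOuterJoin fall back to EMPTY_LIST with a single probe. Below is a toy model of the same contract, using plain Scala types in place of Spark's.

// Toy model of the getOrElse contract added above: get() returns null on a miss
// (avoiding an Option allocation per probe); getOrElse centralizes the null check.
trait Relation[K, V <: AnyRef] {
  def get(key: K): V // returns null when the key is absent

  def getOrElse(key: K, default: V): V = {
    val v = get(key)
    if (v eq null) default else v
  }
}

object RelationDemo extends App {
  val rel = new Relation[String, Seq[Int]] {
    private val table = Map("a" -> Seq(1, 2))
    def get(key: String): Seq[Int] = table.getOrElse(key, null)
  }
  println(rel.getOrElse("a", Seq.empty)) // List(1, 2)
  println(rel.getOrElse("b", Seq.empty)) // List()
}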
