Skip to content

Commit

Permalink
[SPARK-9247] [SQL] Use BytesToBytesMap for broadcast join
Browse files Browse the repository at this point in the history
This PR introduce BytesToBytesMap to UnsafeHashedRelation, use it in executor for better performance.

It serialize all the key and values from java HashMap, put them into a BytesToBytesMap while deserializing. All the values for a same key are stored continuous to have better memory locality.

This PR also address the comments for #7480 , do some clean up.

Author: Davies Liu <davies@databricks.com>

Closes #7592 from davies/unsafe_map2 and squashes the following commits:

42c578a [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_map2
fd09528 [Davies Liu] remove thread local cache and update docs
1c5ad8d [Davies Liu] fix test
5eb1b5a [Davies Liu] address comments in #7480
46f1f22 [Davies Liu] fix style
fc221e0 [Davies Liu] use BytesToBytesMap for broadcast join
  • Loading branch information
Davies Liu authored and davies committed Jul 28, 2015
1 parent 198d181 commit 2182552
Show file tree
Hide file tree
Showing 12 changed files with 214 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
val hashed = buildHashRelation(input.iterator)
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
val hashed = buildHashRelation(input.iterator)
val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash(
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

protected override def doExecute(): RDD[InternalRow] = {
val buildIter = right.execute().map(_.copy()).collect().toIterator
val input = right.execute().map(_.copy()).collect()

if (condition.isEmpty) {
val hashSet = buildKeyHashSet(buildIter)
val hashSet = buildKeyHashSet(input.toIterator)
val broadcastedRelation = sparkContext.broadcast(hashSet)

left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
val hashRelation = buildHashRelation(buildIter)
val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size)
val broadcastedRelation = sparkContext.broadcast(hashRelation)

left.execute().mapPartitions { streamIter =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,11 @@ case class BroadcastNestedLoopJoin(
override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
override def canProcessUnsafeRows: Boolean = true

@transient private[this] lazy val resultProjection: Projection = {
@transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
if (outputsUnsafeRows) {
UnsafeProjection.create(schema)
} else {
new Projection {
override def apply(r: InternalRow): InternalRow = r
}
identity[InternalRow]
}
}

Expand Down Expand Up @@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin(
var streamRowMatched = false

while (i < broadcastedRelation.value.size) {
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
Expand Down Expand Up @@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin(
val buf: CompactBuffer[InternalRow] = new CompactBuffer()
var i = 0
val rel = broadcastedRelation.value
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match {
case (RightOuter | FullOuter, BuildRight) =>
buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
case (LeftOuter | FullOuter, BuildLeft) =>
buf += resultProjection(new JoinedRow(rel(i), rightNulls))
case _ =>
(joinType, buildSide) match {
case (RightOuter | FullOuter, BuildRight) =>
val joinedRow = new JoinedRow
joinedRow.withLeft(leftNulls)
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
buf += resultProjection(joinedRow.withRight(rel(i))).copy()
}
i += 1
}
}
i += 1
case (LeftOuter | FullOuter, BuildLeft) =>
val joinedRow = new JoinedRow
joinedRow.withRight(rightNulls)
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
buf += resultProjection(joinedRow.withLeft(rel(i))).copy()
}
i += 1
}
case _ =>
}
buf.toSeq
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.collection.CompactBuffer


trait HashJoin {
Expand All @@ -44,16 +43,24 @@ trait HashJoin {

override def output: Seq[Attribute] = left.output ++ right.output

protected[this] def supportUnsafe: Boolean = {
protected[this] def isUnsafeMode: Boolean = {
(self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(self.schema))
}

override def outputsUnsafeRows: Boolean = supportUnsafe
override def canProcessUnsafeRows: Boolean = supportUnsafe
override def outputsUnsafeRows: Boolean = isUnsafeMode
override def canProcessUnsafeRows: Boolean = isUnsafeMode
override def canProcessSafeRows: Boolean = !isUnsafeMode

@transient protected lazy val buildSideKeyGenerator: Projection =
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildPlan.output)
} else {
newMutableProjection(buildKeys, buildPlan.output)()
}

@transient protected lazy val streamSideKeyGenerator: Projection =
if (supportUnsafe) {
if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
newMutableProjection(streamedKeys, streamedPlan.output)()
Expand All @@ -65,18 +72,16 @@ trait HashJoin {
{
new Iterator[InternalRow] {
private[this] var currentStreamedRow: InternalRow = _
private[this] var currentHashMatches: CompactBuffer[InternalRow] = _
private[this] var currentHashMatches: Seq[InternalRow] = _
private[this] var currentMatchPosition: Int = -1

// Mutable per row objects.
private[this] val joinRow = new JoinedRow
private[this] val resultProjection: Projection = {
if (supportUnsafe) {
private[this] val resultProjection: (InternalRow) => InternalRow = {
if (isUnsafeMode) {
UnsafeProjection.create(self.schema)
} else {
new Projection {
override def apply(r: InternalRow): InternalRow = r
}
identity[InternalRow]
}
}

Expand Down Expand Up @@ -122,12 +127,4 @@ trait HashJoin {
}
}
}

protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (supportUnsafe) {
UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
} else {
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,36 @@ trait HashOuterJoin {
s"HashOuterJoin should not take $x as the JoinType")
}

protected[this] def supportUnsafe: Boolean = {
protected[this] def isUnsafeMode: Boolean = {
(self.codegenEnabled && joinType != FullOuter
&& UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(self.schema))
}

override def outputsUnsafeRows: Boolean = supportUnsafe
override def canProcessUnsafeRows: Boolean = supportUnsafe
override def outputsUnsafeRows: Boolean = isUnsafeMode
override def canProcessUnsafeRows: Boolean = isUnsafeMode
override def canProcessSafeRows: Boolean = !isUnsafeMode

protected[this] def streamedKeyGenerator(): Projection = {
if (supportUnsafe) {
@transient protected lazy val buildKeyGenerator: Projection =
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildPlan.output)
} else {
newMutableProjection(buildKeys, buildPlan.output)()
}

@transient protected[this] lazy val streamedKeyGenerator: Projection = {
if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
newProjection(streamedKeys, streamedPlan.output)
}
}

@transient private[this] lazy val resultProjection: Projection = {
if (supportUnsafe) {
@transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
if (isUnsafeMode) {
UnsafeProjection.create(self.schema)
} else {
new Projection {
override def apply(r: InternalRow): InternalRow = r
}
identity[InternalRow]
}
}

Expand Down Expand Up @@ -230,12 +236,4 @@ trait HashOuterJoin {

hashTable
}

protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (supportUnsafe) {
UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
} else {
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ trait HashSemiJoin {
protected[this] def supportUnsafe: Boolean = {
(self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
&& UnsafeProjection.canSupport(rightKeys)
&& UnsafeProjection.canSupport(left.schema))
&& UnsafeProjection.canSupport(left.schema)
&& UnsafeProjection.canSupport(right.schema))
}

override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
override def outputsUnsafeRows: Boolean = supportUnsafe
override def canProcessUnsafeRows: Boolean = supportUnsafe
override def canProcessSafeRows: Boolean = !supportUnsafe

@transient protected lazy val leftKeyGenerator: Projection =
if (supportUnsafe) {
Expand Down Expand Up @@ -87,14 +89,6 @@ trait HashSemiJoin {
})
}

protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (supportUnsafe) {
UnsafeHashedRelation(buildIter, rightKeys, right)
} else {
HashedRelation(buildIter, newProjection(rightKeys, right.output))
}
}

protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation): Iterator[InternalRow] = {
Expand Down
Loading

0 comments on commit 2182552

Please sign in to comment.