address comments in #7480
Davies Liu committed Jul 23, 2015
1 parent 46f1f22 commit 5eb1b5a
Showing 12 changed files with 92 additions and 108 deletions.
BroadcastHashJoin.scala
@@ -62,7 +62,7 @@ case class BroadcastHashJoin(
   private val broadcastFuture = future {
     // Note that we use .execute().collect() because we don't want to convert data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
-    val hashed = buildHashRelation(input.iterator)
+    val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size)
     sparkContext.broadcast(hashed)
   }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
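
The build side is already collected into an array here, so the factory call can pass the exact row count as a size estimate instead of relying on the default of 64. A minimal sketch of why pre-sizing the build map matters, using a plain java.util.HashMap as a stand-in for the relation's internal table (the helper below is illustrative, not Spark code):

import java.util.{HashMap => JavaHashMap}

object PresizeSketch {
  // Pre-sizing with the known row count avoids repeated resize-and-rehash
  // passes while the build side is inserted; a small default capacity would
  // rehash several times for a large broadcast relation.
  def buildTable(rows: Array[String], sizeEstimate: Int): JavaHashMap[String, Int] = {
    val table = new JavaHashMap[String, Int](sizeEstimate)
    var i = 0
    while (i < rows.length) {
      table.put(rows(i), i)
      i += 1
    }
    table
  }

  val rows = Array("a", "b", "c")
  val table = buildTable(rows, rows.length) // mirrors HashedRelation(input.iterator, keyGen, input.size)
}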
BroadcastHashOuterJoin.scala
@@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin(
   private val broadcastFuture = future {
     // Note that we use .execute().collect() because we don't want to convert data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
-    val hashed = buildHashRelation(input.iterator)
+    val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
     sparkContext.broadcast(hashed)
   }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
BroadcastLeftSemiJoinHash.scala
@@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash(
     condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

   protected override def doExecute(): RDD[InternalRow] = {
-    val buildIter = right.execute().map(_.copy()).collect().toIterator
+    val input = right.execute().map(_.copy()).collect()

     if (condition.isEmpty) {
-      val hashSet = buildKeyHashSet(buildIter)
+      val hashSet = buildKeyHashSet(input.toIterator)
       val broadcastedRelation = sparkContext.broadcast(hashSet)

       left.execute().mapPartitions { streamIter =>
         hashSemiJoin(streamIter, broadcastedRelation.value)
       }
     } else {
-      val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size)
+      val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size)
       val broadcastedRelation = sparkContext.broadcast(hashRelation)

       left.execute().mapPartitions { streamIter =>
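
The collected Array replaces a one-shot iterator: an Array can report its size for the estimate above and hand out a fresh iterator wherever one is needed. A small illustrative sketch of the distinction:

object IteratorVsArraySketch {
  val input: Array[Int] = Array(1, 2, 3)

  // An Iterator is single-use; once consumed it is exhausted.
  val once = input.toIterator
  val consumed = once.sum      // 6; `once` is now empty

  // The backing Array can produce independent iterators and knows its size.
  val again = input.toIterator // a second, fresh traversal
  val sizeHint = input.size    // usable as a hash-table size estimate
}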
BroadcastNestedLoopJoin.scala
@@ -96,7 +96,6 @@ case class BroadcastNestedLoopJoin(
         var streamRowMatched = false

         while (i < broadcastedRelation.value.size) {
-          // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
           buildSide match {
             case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
@@ -135,17 +134,26 @@ case class BroadcastNestedLoopJoin(
       val buf: CompactBuffer[InternalRow] = new CompactBuffer()
       var i = 0
       val rel = broadcastedRelation.value
-      while (i < rel.length) {
-        if (!allIncludedBroadcastTuples.contains(i)) {
-          (joinType, buildSide) match {
-            case (RightOuter | FullOuter, BuildRight) =>
-              buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
-            case (LeftOuter | FullOuter, BuildLeft) =>
-              buf += resultProjection(new JoinedRow(rel(i), rightNulls))
-            case _ =>
-          }
-        }
-        i += 1
+      (joinType, buildSide) match {
+        case (RightOuter | FullOuter, BuildRight) =>
+          val joinedRow = new JoinedRow
+          joinedRow.withLeft(leftNulls)
+          while (i < rel.length) {
+            if (!allIncludedBroadcastTuples.contains(i)) {
+              buf += resultProjection(joinedRow.withRight(rel(i)))
+            }
+            i += 1
+          }
+        case (LeftOuter | FullOuter, BuildLeft) =>
+          val joinedRow = new JoinedRow
+          joinedRow.withRight(rightNulls)
+          while (i < rel.length) {
+            if (!allIncludedBroadcastTuples.contains(i)) {
+              buf += resultProjection(joinedRow.withLeft(rel(i)))
+            }
+            i += 1
+          }
+        case _ =>
       }
       buf.toSeq
     }
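
The rewritten padding loop hoists the (joinType, buildSide) match out of the per-row loop and reuses a single mutable JoinedRow with its null-padded side fixed once, so each unmatched row costs neither a pattern match nor a JoinedRow allocation. An illustrative sketch of the pattern with simplified stand-in types (not Spark's classes):

object ReuseSketch {
  // One mutable pair stands in for Spark's JoinedRow; the constant
  // null-padded side is set once, outside the hot loop.
  final class ReusablePair(var left: String, var right: String) {
    def withLeft(l: String): ReusablePair = { left = l; this }
    def withRight(r: String): ReusablePair = { right = r; this }
  }

  def padRightBuild(rows: Array[String], matched: Set[Int], nulls: String): Seq[String] = {
    val buf = Seq.newBuilder[String]
    val pair = new ReusablePair(nulls, nulls) // single allocation for the whole loop
    pair.withLeft(nulls)                      // branch decided and side fixed up front
    var i = 0
    while (i < rows.length) {
      if (!matched.contains(i)) {
        pair.withRight(rows(i))
        buf += s"${pair.left}|${pair.right}"  // stand-in for resultProjection(joinedRow)
      }
      i += 1
    }
    buf.result()
  }
}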
HashJoin.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.util.collection.CompactBuffer


 trait HashJoin {
@@ -44,16 +43,24 @@ trait HashJoin {

   override def output: Seq[Attribute] = left.output ++ right.output

-  protected[this] def supportUnsafe: Boolean = {
+  protected[this] def isUnsafeMode: Boolean = {
     (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
       && UnsafeProjection.canSupport(self.schema))
   }

-  override def outputsUnsafeRows: Boolean = supportUnsafe
-  override def canProcessUnsafeRows: Boolean = supportUnsafe
+  override def outputsUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessSafeRows: Boolean = !isUnsafeMode
+
+  @transient protected lazy val buildSideKeyGenerator: Projection =
+    if (isUnsafeMode) {
+      UnsafeProjection.create(buildKeys, buildPlan.output)
+    } else {
+      newMutableProjection(buildKeys, buildPlan.output)()
+    }

   @transient protected lazy val streamSideKeyGenerator: Projection =
-    if (supportUnsafe) {
+    if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedPlan.output)
     } else {
       newMutableProjection(streamedKeys, streamedPlan.output)()
@@ -65,18 +72,16 @@ trait HashJoin {
   {
     new Iterator[InternalRow] {
       private[this] var currentStreamedRow: InternalRow = _
-      private[this] var currentHashMatches: CompactBuffer[InternalRow] = _
+      private[this] var currentHashMatches: Seq[InternalRow] = _
      private[this] var currentMatchPosition: Int = -1

       // Mutable per row objects.
       private[this] val joinRow = new JoinedRow2
-      private[this] val resultProjection: Projection = {
-        if (supportUnsafe) {
+      private[this] val resultProjection: (InternalRow) => InternalRow = {
+        if (isUnsafeMode) {
           UnsafeProjection.create(self.schema)
         } else {
-          new Projection {
-            override def apply(r: InternalRow): InternalRow = r
-          }
+          identity[InternalRow]
         }
       }

@@ -122,12 +127,4 @@ trait HashJoin {
       }
     }
   }
-
-  protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
-    if (supportUnsafe) {
-      UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
-    } else {
-      HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
-    }
-  }
 }
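
Note the type of resultProjection widening from Projection to a plain (InternalRow) => InternalRow: the safe-row path no longer needs an anonymous Projection wrapper just to pass rows through. A sketch of the idea with simplified stand-in types:

object IdentityProjectionSketch {
  trait Row
  abstract class Projection extends (Row => Row)

  // Before: a no-op projection still required an anonymous subclass per iterator.
  val before: Projection = new Projection {
    override def apply(r: Row): Row = r
  }

  // After: with the field typed as a plain function, the no-op is the stock identity.
  val after: Row => Row = identity[Row]
}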
HashOuterJoin.scala
@@ -75,30 +75,36 @@
       s"HashOuterJoin should not take $x as the JoinType")
   }

-  protected[this] def supportUnsafe: Boolean = {
+  protected[this] def isUnsafeMode: Boolean = {
     (self.codegenEnabled && joinType != FullOuter
       && UnsafeProjection.canSupport(buildKeys)
       && UnsafeProjection.canSupport(self.schema))
   }

-  override def outputsUnsafeRows: Boolean = supportUnsafe
-  override def canProcessUnsafeRows: Boolean = supportUnsafe
+  override def outputsUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessSafeRows: Boolean = !isUnsafeMode

-  protected[this] def streamedKeyGenerator(): Projection = {
-    if (supportUnsafe) {
+  @transient protected lazy val buildKeyGenerator: Projection =
+    if (isUnsafeMode) {
+      UnsafeProjection.create(buildKeys, buildPlan.output)
+    } else {
+      newMutableProjection(buildKeys, buildPlan.output)()
+    }
+
+  @transient protected[this] lazy val streamedKeyGenerator: Projection = {
+    if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedPlan.output)
     } else {
       newProjection(streamedKeys, streamedPlan.output)
     }
   }

-  @transient private[this] lazy val resultProjection: Projection = {
-    if (supportUnsafe) {
+  @transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
+    if (isUnsafeMode) {
       UnsafeProjection.create(self.schema)
     } else {
-      new Projection {
-        override def apply(r: InternalRow): InternalRow = r
-      }
+      identity[InternalRow]
     }
   }

@@ -230,12 +236,4 @@

     hashTable
   }
-
-  protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
-    if (supportUnsafe) {
-      UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
-    } else {
-      HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
-    }
-  }
 }
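
The new key generators follow the @transient lazy val idiom that is common in Spark plan nodes: the field is excluded when the plan object is serialized and rebuilt lazily on first access wherever it is used. A hedged, self-contained sketch of the idiom (illustrative class, not Spark's):

object TransientLazySketch {
  class Plan(val keys: Seq[String]) extends Serializable {
    // Not written out when the plan is serialized; recomputed on first access,
    // which suits values that are non-serializable or cheap to re-derive.
    @transient lazy val keyGenerator: String => String = {
      val joined = keys.mkString(",") // imagine expensive projection codegen here
      row => s"$joined:$row"
    }
  }
}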
HashSemiJoin.scala
@@ -35,11 +35,13 @@
   protected[this] def supportUnsafe: Boolean = {
     (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
       && UnsafeProjection.canSupport(rightKeys)
-      && UnsafeProjection.canSupport(left.schema))
+      && UnsafeProjection.canSupport(left.schema)
+      && UnsafeProjection.canSupport(right.schema))
   }

-  override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
+  override def outputsUnsafeRows: Boolean = supportUnsafe
   override def canProcessUnsafeRows: Boolean = supportUnsafe
+  override def canProcessSafeRows: Boolean = !supportUnsafe

   @transient protected lazy val leftKeyGenerator: Projection =
     if (supportUnsafe) {
@@ -87,14 +89,6 @@
     })
   }

-  protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
-    if (supportUnsafe) {
-      UnsafeHashedRelation(buildIter, rightKeys, right)
-    } else {
-      HashedRelation(buildIter, newProjection(rightKeys, right.output))
-    }
-  }
-
   protected def hashSemiJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation): Iterator[InternalRow] = {
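
The three flags now move together: an operator that produces unsafe rows also declares that it consumes them, and vice versa (previously outputsUnsafeRows tracked the right child instead). A hedged sketch of the contract these overrides encode; the planner-side check is paraphrased from how Spark decides to insert row-format converters, not copied from it:

object RowFormatSketch {
  trait Op {
    def outputsUnsafeRows: Boolean
    def canProcessUnsafeRows: Boolean
    def canProcessSafeRows: Boolean
  }

  // A converter is needed exactly where a producer's output format is one
  // the consumer cannot process.
  def needsConversion(producer: Op, consumer: Op): Boolean =
    if (producer.outputsUnsafeRows) !consumer.canProcessUnsafeRows
    else !consumer.canProcessSafeRows
}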
HashedRelation.scala
@@ -25,11 +25,10 @@ import scala.reflect.ClassTag

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.util.collection.CompactBuffer

@@ -38,7 +37,7 @@ import org.apache.spark.util.collection.CompactBuffer
  * object.
  */
 private[joins] sealed trait HashedRelation {
-  def get(key: InternalRow): CompactBuffer[InternalRow]
+  def get(key: InternalRow): Seq[InternalRow]

   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
@@ -65,9 +64,9 @@ private[joins] final class GeneralHashedRelation(
     private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
   extends HashedRelation with Externalizable {

-  def this() = this(null) // Needed for serialization
+  private def this() = this(null) // Needed for serialization

-  override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key)
+  override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)

   override def writeExternal(out: ObjectOutput): Unit = {
     writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -87,9 +86,9 @@ private[joins]
 final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
   extends HashedRelation with Externalizable {

-  def this() = this(null) // Needed for serialization
+  private def this() = this(null) // Needed for serialization

-  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+  override def get(key: InternalRow): Seq[InternalRow] = {
     val v = hashTable.get(key)
     if (v eq null) null else CompactBuffer(v)
   }
@@ -115,6 +114,10 @@ private[joins] object HashedRelation {
       keyGenerator: Projection,
       sizeEstimate: Int = 64): HashedRelation = {

+    if (keyGenerator.isInstanceOf[UnsafeProjection]) {
+      return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+    }
+
     // TODO: Use Spark's HashMap implementation.
     val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
     var currentRow: InternalRow = null
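
With this guard, the choice of implementation is driven entirely by the runtime type of the key generator the caller passes in; the join operators above pick the unsafe or safe relation simply by which projection isUnsafeMode tells them to build. A minimal sketch of the dispatch pattern with illustrative types (not Spark's):

object DispatchSketch {
  trait Proj
  final class SafeProj extends Proj
  final class UnsafeProj extends Proj

  sealed trait Relation
  final case class SafeRelation(sizeHint: Int) extends Relation
  final case class UnsafeRelation(sizeHint: Int) extends Relation

  // The factory inspects the key generator's class, so callers select the
  // implementation implicitly; no separate "unsafe" entry point is needed.
  def makeRelation(keyGen: Proj, sizeHint: Int = 64): Relation = keyGen match {
    case _: UnsafeProj => UnsafeRelation(sizeHint)
    case _             => SafeRelation(sizeHint)
  }

  val rel = makeRelation(new UnsafeProj, sizeHint = 1024) // UnsafeRelation(1024)
}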
@@ -158,7 +161,7 @@
 /**
  * An extended CompactBuffer that could grow and update.
  */
-class MutableCompactBuffer[T: ClassTag] extends CompactBuffer[T] {
+private[joins] class MutableCompactBuffer[T: ClassTag] extends CompactBuffer[T] {
   override def growToSize(newSize: Int): Unit = super.growToSize(newSize)
   override def update(i: Int, v: T): Unit = super.update(i, v)
 }
@@ -171,15 +174,15 @@ private[joins] final class UnsafeHashedRelation(
     private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
   extends HashedRelation with Externalizable {

-  def this() = this(null) // Needed for serialization
+  private[joins] def this() = this(null) // Needed for serialization

   // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
   @transient private[this] var binaryMap: BytesToBytesMap = _

   // A pool of compact buffers to reduce memory garbage
   @transient private[this] val bufferPool = new ThreadLocal[MutableCompactBuffer[UnsafeRow]]

-  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+  override def get(key: InternalRow): Seq[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]

     if (binaryMap != null) {
@@ -212,14 +215,14 @@
           i += 1
           offset += sizeInBytes
         }
-        buffer.asInstanceOf[CompactBuffer[InternalRow]]
+        buffer
       } else {
         null
       }

     } else {
       // Use the JavaHashMap in Local mode or ShuffleHashJoin
-      hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+      hashTable.get(unsafeKey)
     }
   }

@@ -297,32 +300,14 @@

   def apply(
       input: Iterator[InternalRow],
-      buildKeys: Seq[Expression],
-      buildPlan: SparkPlan,
-      sizeEstimate: Int = 64): HashedRelation = {
-    val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output))
-    apply(input, boundedKeys, buildPlan.schema, sizeEstimate)
-  }
-
-  // Used for tests
-  def apply(
-      input: Iterator[InternalRow],
-      buildKeys: Seq[Expression],
-      rowSchema: StructType,
+      keyGenerator: UnsafeProjection,
       sizeEstimate: Int): HashedRelation = {

     val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
-    val toUnsafe = UnsafeProjection.create(rowSchema)
-    val keyGenerator = UnsafeProjection.create(buildKeys)

     // Create a mapping of buildKeys -> rows
     while (input.hasNext) {
-      val currentRow = input.next()
-      val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
-        currentRow.asInstanceOf[UnsafeRow]
-      } else {
-        toUnsafe(currentRow)
-      }
+      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
       val rowKey = keyGenerator(unsafeRow)
       if (!rowKey.anyNull) {
         val existingMatchList = hashTable.get(rowKey)
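
The builder also drops its per-row isInstanceOf/convert branch: since the factory only routes here when the key generator is an UnsafeProjection, the build side is expected to already produce UnsafeRows, and the cast documents that invariant by failing fast if it is violated. A short illustrative sketch (simplified types, not Spark's):

object StrictBuilderSketch {
  trait Row
  final class BinaryRow(val bytes: Array[Byte]) extends Row

  // The old builder converted row-by-row; the new one assumes the producer
  // already emits the binary format and throws ClassCastException if not.
  def collectBinary(input: Iterator[Row]): Vector[BinaryRow] =
    input.map(_.asInstanceOf[BinaryRow]).toVector
}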
LeftSemiJoinHash.scala
@@ -46,7 +46,7 @@ case class LeftSemiJoinHash(
       val hashSet = buildKeyHashSet(buildIter)
       hashSemiJoin(streamIter, hashSet)
     } else {
-      val hashRelation = buildHashRelation(buildIter)
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
       hashSemiJoin(streamIter, hashRelation)
     }
   }
(3 more changed files not shown)
