use BytesToBytesMap for broadcast join
Davies Liu committed Jul 23, 2015
1 parent a721ee5 commit fc221e0
Showing 3 changed files with 134 additions and 13 deletions.
core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala
@@ -51,7 +51,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable
}
}

- private def update(position: Int, value: T): Unit = {
+ protected def update(position: Int, value: T): Unit = {
if (position < 0 || position >= curSize) {
throw new IndexOutOfBoundsException
}
@@ -125,7 +125,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable
}

/** Increase our size to newSize and grow the backing array if needed. */
- private def growToSize(newSize: Int): Unit = {
+ protected def growToSize(newSize: Int): Unit = {
if (newSize < 0) {
throw new UnsupportedOperationException("Can't grow buffer past Int.MaxValue elements")
}
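Both hunks in this file do one thing: widen update and growToSize from private to protected so that a subclass can re-expose them. A minimal sketch of the pattern this enables (ReusableBuffer is a hypothetical name; the commit's actual subclass is MutableCompactBuffer, below):

// Hypothetical subclass: with the members protected rather than private,
// a subclass can surface CompactBuffer's mutators without reimplementing
// the backing-array growth logic.
class ReusableBuffer[T: scala.reflect.ClassTag] extends CompactBuffer[T] {
  def setSize(n: Int): Unit = growToSize(n)   // was a compile error while private
  def set(i: Int, v: T): Unit = update(i, v)
}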
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,12 +18,18 @@
package org.apache.spark.sql.execution.joins

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
import org.apache.spark.util.collection.CompactBuffer


@@ -149,31 +155,141 @@ private[joins] object HashedRelation {
}
}

/**
* An extended CompactBuffer that can be grown and updated in place from outside.
*/
class MutableCompactBuffer[T: ClassTag] extends CompactBuffer[T] {
override def growToSize(newSize: Int) = super.growToSize(newSize)
override def update(i: Int, v: T) = super.update(i, v)
}
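A design note on the class above: get() (below) refills a single thread-local MutableCompactBuffer in place on every lookup, so the returned rows are only valid until the next call on that thread. A sketch of the resulting caller-side contract (the collectMatches helper is illustrative, not part of the commit):

// Rows in the pooled buffer are overwritten by the next lookup on this
// thread, so anything retained across lookups must be copied out.
def collectMatches(relation: HashedRelation, key: InternalRow): Seq[InternalRow] = {
  val matches = relation.get(key)
  if (matches == null) Seq.empty else matches.map(_.copy())
}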

/**
* A HashedRelation for UnsafeRow, which is backed by a BytesToBytesMap that maps each key
* to a sequence of values.
- *
- * TODO(davies): use BytesToBytesMap
*/
private[joins] final class UnsafeHashedRelation(
private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
extends HashedRelation with Externalizable {

def this() = this(null) // Needed for serialization

// Use BytesToBytesMap in the executor for better performance (it's created during deserialization)
@transient private[this] var binaryMap: BytesToBytesMap = _

// A pool of compact buffers to reduce memory garbage
@transient private[this] val bufferPool = new ThreadLocal[MutableCompactBuffer[UnsafeRow]]

override def get(key: InternalRow): CompactBuffer[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
- // Thanks to type eraser
- hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]

if (binaryMap != null) {
// Used in Broadcast join
val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
unsafeKey.getSizeInBytes)
if (loc.isDefined) {
// thread-local buffer
var buffer = bufferPool.get()
if (buffer == null) {
buffer = new MutableCompactBuffer[UnsafeRow]
bufferPool.set(buffer)
}

val base = loc.getValueAddress.getBaseObject
var offset = loc.getValueAddress.getBaseOffset
val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
var i = 0
while (offset < last) {
val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
offset += 8

// try to re-use the UnsafeRows in the buffer, to reduce garbage
buffer.growToSize(i + 1)
if (buffer(i) == null) {
buffer(i) = new UnsafeRow
}
buffer(i).pointTo(base, offset, numFields, sizeInBytes, null)
i += 1
offset += sizeInBytes
}
buffer.asInstanceOf[CompactBuffer[InternalRow]]
} else {
null
}

} else {
// Use the JavaHashMap in local mode or for ShuffledHashJoin
hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
}
}
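Each values blob decoded above is a concatenation of [num of fields: Int][num of bytes: Int][row bytes] records, with the two ints stored in native byte order. A standalone sketch of the same walk, substituting java.nio.ByteBuffer for PlatformDependent.UNSAFE (an illustrative substitution; rowSlices is not part of the commit):

import java.nio.{ByteBuffer, ByteOrder}
import scala.collection.mutable.ArrayBuffer

// Walk a values blob of [numFields: Int][sizeInBytes: Int][row bytes]
// records, mirroring the UNSAFE.getInt() walk in get() above.
def rowSlices(blob: Array[Byte]): Seq[(Int, Array[Byte])] = {
  val buf = ByteBuffer.wrap(blob).order(ByteOrder.nativeOrder())
  val out = ArrayBuffer.empty[(Int, Array[Byte])]
  while (buf.remaining() > 0) {
    val numFields = buf.getInt()     // [num of fields]
    val sizeInBytes = buf.getInt()   // [num of bytes]
    val rowBytes = new Array[Byte](sizeInBytes)
    buf.get(rowBytes)                // [row bytes]
    out += ((numFields, rowBytes))
  }
  out.toSeq
}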

override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ out.writeInt(hashTable.size())

val iter = hashTable.entrySet().iterator()
while (iter.hasNext) {
val entry = iter.next()
val key = entry.getKey
val values = entry.getValue

// write all the values as a single byte array
var totalSize = 0L
var i = 0
while (i < values.size) {
totalSize += values(i).getSizeInBytes + 4 + 4
i += 1
}
assert(totalSize < Integer.MAX_VALUE, "values are too big")

// [key size] [values size] [key bytes] [values bytes]
out.writeInt(key.getSizeInBytes)
out.writeInt(totalSize.toInt)
out.write(key.getBytes)
i = 0
while (i < values.size) {
// [num of fields] [num of bytes] [row bytes]
// write the integers in native order, so they can be read by UNSAFE.getInt()
if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
out.writeInt(values(i).length())
out.writeInt(values(i).getSizeInBytes)
} else {
out.writeInt(Integer.reverseBytes(values(i).length()))
out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
}
out.write(values(i).getBytes)
i += 1
}
}
}
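ObjectOutput.writeInt always emits big-endian bytes, so on little-endian machines the loop above byte-swaps each int with Integer.reverseBytes; after deserialization the ints sit in native order and can be read directly by UNSAFE.getInt(). A small self-contained check of the trick (the helper name is illustrative):

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.{ByteBuffer, ByteOrder}

// DataOutput.writeInt always writes big-endian; reversing the bytes first
// makes the on-wire order little-endian, i.e. native order on x86, so the
// deserialized blob can be read back with native-order (UNSAFE-style) reads.
def writeNativeOrderInt(out: DataOutputStream, v: Int): Unit = {
  if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) out.writeInt(v)
  else out.writeInt(Integer.reverseBytes(v))
}

val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
writeNativeOrderInt(out, 0x0A0B0C0D)
out.flush()
val readBack = ByteBuffer.wrap(bos.toByteArray).order(ByteOrder.nativeOrder()).getInt()
assert(readBack == 0x0A0B0C0D)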

override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ val nKeys = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2) // oversize to reduce hash collisions

var i = 0
var keyBuffer = new Array[Byte](1024)
var valuesBuffer = new Array[Byte](1024)
while (i < nKeys) {
val keySize = in.readInt()
val valuesSize = in.readInt()
if (keySize > keyBuffer.size) {
keyBuffer = new Array[Byte](keySize)
}
in.readFully(keyBuffer, 0, keySize)
if (valuesSize > valuesBuffer.size) {
valuesBuffer = new Array[Byte](valuesSize)
}
in.readFully(valuesBuffer, 0, valuesSize)

// put it into binary map
val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
assert(!loc.isDefined, "Duplicate key found!")
loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
i += 1
}
}
}
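Putting the two halves together, the wire format is [nKeys] followed by [key size][values size][key bytes][values bytes] per entry; only the ints inside each values blob use native order, while the outer framing uses plain big-endian writeInt/readInt. A sketch that parses just the outer framing (readEntries is illustrative, not part of the commit):

import java.io.{ByteArrayInputStream, DataInputStream}

// Parse the outer framing that readExternal consumes:
// [nKeys] then, per entry, [key size][values size][key bytes][values bytes].
def readEntries(bytes: Array[Byte]): Seq[(Array[Byte], Array[Byte])] = {
  val in = new DataInputStream(new ByteArrayInputStream(bytes))
  val nKeys = in.readInt()
  (0 until nKeys).map { _ =>
    val keySize = in.readInt()
    val valuesSize = in.readInt()
    val key = new Array[Byte](keySize)
    val values = new Array[Byte](valuesSize)
    in.readFully(key)
    in.readFully(values)
    (key, values)
  }
}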

@@ -195,7 +311,6 @@ private[joins] object UnsafeHashedRelation {
rowSchema: StructType,
sizeEstimate: Int): HashedRelation = {

- // TODO: Use BytesToBytesMap.
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
val toUnsafe = UnsafeProjection.create(rowSchema)
val keyGenerator = UnsafeProjection.create(buildKeys)
sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -17,11 +17,12 @@

package org.apache.spark.sql.execution.joins

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
- import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
+ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.collection.CompactBuffer


@@ -80,8 +81,13 @@ class HashedRelationSuite extends SparkFunSuite {
data2 += unsafeData(2).copy()
assert(hashed.get(unsafeData(2)) === data2)

- val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
-   .asInstanceOf[UnsafeHashedRelation]
+ val os = new ByteArrayOutputStream()
+ val out = new ObjectOutputStream(os)
+ hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
+ out.flush()
+ val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+ val hashed2 = new UnsafeHashedRelation()
+ hashed2.readExternal(in)
assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
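The test now drives the Externalizable hooks directly instead of round-tripping through SparkSqlSerializer, since it is readExternal that builds the BytesToBytesMap being exercised. The same round trip as a reusable sketch (the roundTrip helper is hypothetical):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable, ObjectInputStream, ObjectOutputStream}

// Hypothetical helper capturing the round trip used in the test above.
def roundTrip[T <: Externalizable](value: T, fresh: () => T): T = {
  val bos = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bos)
  value.writeExternal(out)
  out.flush()
  val in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
  val result = fresh()
  result.readExternal(in)
  result
}

// e.g. val hashed2 = roundTrip(hashed.asInstanceOf[UnsafeHashedRelation],
//                              () => new UnsafeHashedRelation())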
