Skip to content

Commit

Permalink
remove thread local cache and update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jul 28, 2015
1 parent 1c5ad8d commit fd09528
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable
}
}

protected def update(position: Int, value: T): Unit = {
private def update(position: Int, value: T): Unit = {
if (position < 0 || position >= curSize) {
throw new IndexOutOfBoundsException
}
Expand Down Expand Up @@ -125,7 +125,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable
}

/** Increase our size to newSize and grow the backing array if needed. */
protected def growToSize(newSize: Int): Unit = {
private def growToSize(newSize: Int): Unit = {
if (newSize < 0) {
throw new UnsupportedOperationException("Can't grow buffer past Int.MaxValue elements")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
Expand Down Expand Up @@ -159,16 +157,21 @@ private[joins] object HashedRelation {
}

/**
* An extended CompactBuffer that could grow and update.
*/
private[joins] class MutableCompactBuffer[T: ClassTag] extends CompactBuffer[T] {
override def growToSize(newSize: Int): Unit = super.growToSize(newSize)
override def update(i: Int, v: T): Unit = super.update(i, v)
}

/**
* A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a
* sequence of values.
* A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
* into a sequence of values.
*
* When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
* BytesToBytesMap for better memory performance (multiple values for the same are stored as a
* continuous byte array.
*
* It's serialized in the following format:
* [number of keys]
* [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
* ...
*
* All the values are serialized as following:
* [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
* ...
*/
private[joins] final class UnsafeHashedRelation(
private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
Expand All @@ -179,9 +182,6 @@ private[joins] final class UnsafeHashedRelation(
// Use BytesToBytesMap in executor for better performance (it's created when deserialization)
@transient private[this] var binaryMap: BytesToBytesMap = _

// A pool of compact buffers to reduce memory garbage
@transient private[this] val bufferPool = new ThreadLocal[MutableCompactBuffer[UnsafeRow]]

override def get(key: InternalRow): Seq[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]

Expand All @@ -190,29 +190,19 @@ private[joins] final class UnsafeHashedRelation(
val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
unsafeKey.getSizeInBytes)
if (loc.isDefined) {
// thread-local buffer
var buffer = bufferPool.get()
if (buffer == null) {
buffer = new MutableCompactBuffer[UnsafeRow]
bufferPool.set(buffer)
}
val buffer = CompactBuffer[UnsafeRow]()

val base = loc.getValueAddress.getBaseObject
var offset = loc.getValueAddress.getBaseOffset
val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
var i = 0
while (offset < last) {
val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
offset += 8

// try to re-use the UnsafeRow in buffer, to reduce garbage
buffer.growToSize(i + 1)
if (buffer(i) == null) {
buffer(i) = new UnsafeRow
}
buffer(i).pointTo(base, offset, numFields, sizeInBytes, null)
i += 1
val row = new UnsafeRow
row.pointTo(base, offset, numFields, sizeInBytes, null)
buffer += row
offset += sizeInBytes
}
buffer
Expand Down

0 comments on commit fd09528

Please sign in to comment.