diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala index c47efe1e80443..4d43d8d5cc8d8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -51,7 +51,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable } } - protected def update(position: Int, value: T): Unit = { + private def update(position: Int, value: T): Unit = { if (position < 0 || position >= curSize) { throw new IndexOutOfBoundsException } @@ -125,7 +125,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable } /** Increase our size to newSize and grow the backing array if needed. */ - protected def growToSize(newSize: Int): Unit = { + private def growToSize(newSize: Int): Unit = { if (newSize < 0) { throw new UnsupportedOperationException("Can't grow buffer past Int.MaxValue elements") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 2f9b3130fad3c..19718a164d9f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,8 +21,6 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} -import scala.reflect.ClassTag - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -159,16 +157,21 @@ private[joins] object HashedRelation { } /** - * An extended CompactBuffer that could grow and update. 
- */ -private[joins] class MutableCompactBuffer[T: ClassTag] extends CompactBuffer[T] { - override def growToSize(newSize: Int): Unit = super.growToSize(newSize) - override def update(i: Int, v: T): Unit = super.update(i, v) -} - -/** - * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a - * sequence of values. + * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key + * into a sequence of values. + * + * When it's created, it uses HashMap. After it's serialized and deserialized, it switches to + * BytesToBytesMap for better memory performance (multiple values for the same key are stored as a + * continuous byte array). + * + * It's serialized in the following format: + * [number of keys] + * [size of key] [size of all values in bytes] [key bytes] [bytes for all values] + * ... + * + * All the values are serialized as follows: + * [number of fields] [number of bytes] [underlying bytes of UnsafeRow] + * ... 
*/ private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) @@ -179,9 +182,6 @@ private[joins] final class UnsafeHashedRelation( // Use BytesToBytesMap in executor for better performance (it's created when deserialization) @transient private[this] var binaryMap: BytesToBytesMap = _ - // A pool of compact buffers to reduce memory garbage - @transient private[this] val bufferPool = new ThreadLocal[MutableCompactBuffer[UnsafeRow]] - override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] @@ -190,29 +190,19 @@ private[joins] final class UnsafeHashedRelation( val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, unsafeKey.getSizeInBytes) if (loc.isDefined) { - // thread-local buffer - var buffer = bufferPool.get() - if (buffer == null) { - buffer = new MutableCompactBuffer[UnsafeRow] - bufferPool.set(buffer) - } + val buffer = CompactBuffer[UnsafeRow]() val base = loc.getValueAddress.getBaseObject var offset = loc.getValueAddress.getBaseOffset val last = loc.getValueAddress.getBaseOffset + loc.getValueLength - var i = 0 while (offset < last) { val numFields = PlatformDependent.UNSAFE.getInt(base, offset) val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) offset += 8 - // try to re-use the UnsafeRow in buffer, to reduce garbage - buffer.growToSize(i + 1) - if (buffer(i) == null) { - buffer(i) = new UnsafeRow - } - buffer(i).pointTo(base, offset, numFields, sizeInBytes, null) - i += 1 + val row = new UnsafeRow + row.pointTo(base, offset, numFields, sizeInBytes, null) + buffer += row offset += sizeInBytes } buffer