From fc221e078307d32fb0333d59e6df29e06474fbc4 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 22 Jul 2015 01:23:02 -0700
Subject: [PATCH] use BytesToBytesMap for broadcast join

---
 .../spark/util/collection/CompactBuffer.scala |   4 +-
 .../sql/execution/joins/HashedRelation.scala  | 129 +++++++++++++++++-
 .../execution/joins/HashedRelationSuite.scala |  14 +-
 3 files changed, 134 insertions(+), 13 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala
index 4d43d8d5cc8d8..c47efe1e80443 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala
@@ -51,7 +51,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable
     }
   }
 
-  private def update(position: Int, value: T): Unit = {
+  protected def update(position: Int, value: T): Unit = {
     if (position < 0 || position >= curSize) {
       throw new IndexOutOfBoundsException
     }
@@ -125,7 +125,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable
   }
 
   /** Increase our size to newSize and grow the backing array if needed. */
-  private def growToSize(newSize: Int): Unit = {
+  protected def growToSize(newSize: Int): Unit = {
     if (newSize < 0) {
       throw new UnsupportedOperationException("Can't grow buffer past Int.MaxValue elements")
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8d5731afd59b8..1b78b986eaa1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,12 +18,18 @@
 package org.apache.spark.sql.execution.joins
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
 import org.apache.spark.util.collection.CompactBuffer
 
 
@@ -149,12 +155,17 @@ private[joins] object HashedRelation {
   }
 }
 
+/**
+ * An extended CompactBuffer that could grow and update.
+ */
+class MutableCompactBuffer[T: ClassTag] extends CompactBuffer[T] {
+  override def growToSize(newSize: Int) = super.growToSize(newSize)
+  override def update(i: Int, v: T) = super.update(i, v)
+}
 
 /**
  * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a
  * sequence of values.
- *
- * TODO(davies): use BytesToBytesMap
  */
 private[joins] final class UnsafeHashedRelation(
     private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
@@ -162,18 +173,123 @@ private[joins] final class UnsafeHashedRelation(
 
   def this() = this(null)  // Needed for serialization
 
+  // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
+  @transient private[this] var binaryMap: BytesToBytesMap = _
+
+  // A pool of compact buffers to reduce memory garbage
+  @transient private[this] val bufferPool = new ThreadLocal[MutableCompactBuffer[UnsafeRow]]
+
   override def get(key: InternalRow): CompactBuffer[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
-    // Thanks to type eraser
-    hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+
+    if (binaryMap != null) {
+      // Used in Broadcast join
+      val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+        unsafeKey.getSizeInBytes)
+      if (loc.isDefined) {
+        // thread-local buffer
+        var buffer = bufferPool.get()
+        if (buffer == null) {
+          buffer = new MutableCompactBuffer[UnsafeRow]
+          bufferPool.set(buffer)
+        }
+
+        val base = loc.getValueAddress.getBaseObject
+        var offset = loc.getValueAddress.getBaseOffset
+        val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
+        var i = 0
+        while (offset < last) {
+          val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
+          val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
+          offset += 8
+
+          // try to re-use the UnsafeRow in buffer, to reduce garbage
+          buffer.growToSize(i + 1)
+          if (buffer(i) == null) {
+            buffer(i) = new UnsafeRow
+          }
+          buffer(i).pointTo(base, offset, numFields, sizeInBytes, null)
+          i += 1
+          offset += sizeInBytes
+        }
+        buffer.asInstanceOf[CompactBuffer[InternalRow]]
+      } else {
+        null
+      }
+
+    } else {
+      // Use the JavaHashMap in Local mode or ShuffleHashJoin
+      hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+    }
   }
 
   override def writeExternal(out: ObjectOutput): Unit = {
-    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+    out.writeInt(hashTable.size())
+
+    val iter = hashTable.entrySet().iterator()
+    while (iter.hasNext) {
+      val entry = iter.next()
+      val key = entry.getKey
+      val values = entry.getValue
+
+      // write all the values as single byte array
+      var totalSize = 0L
+      var i = 0
+      while (i < values.size) {
+        totalSize += values(i).getSizeInBytes + 4 + 4
+        i += 1
+      }
+      assert(totalSize < Integer.MAX_VALUE, "values are too big")
+
+      // [key size] [values size] [key bytes] [values bytes]
+      out.writeInt(key.getSizeInBytes)
+      out.writeInt(totalSize.toInt)
+      out.write(key.getBytes)
+      i = 0
+      while (i < values.size) {
+        // [num of fields] [num of bytes] [row bytes]
+        // write the integer in native order, so they can be read by UNSAFE.getInt()
+        if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
+          out.writeInt(values(i).length())
+          out.writeInt(values(i).getSizeInBytes)
+        } else {
+          out.writeInt(Integer.reverseBytes(values(i).length()))
+          out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+        }
+        out.write(values(i).getBytes)
+        i += 1
+      }
+    }
   }
 
   override def readExternal(in: ObjectInput): Unit = {
-    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+    val nKeys = in.readInt()
+    // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
+    val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2)  // reduce hash collision
+
+    var i = 0
+    var keyBuffer = new Array[Byte](1024)
+    var valuesBuffer = new Array[Byte](1024)
+    while (i < nKeys) {
+      val keySize = in.readInt()
+      val valuesSize = in.readInt()
+      if (keySize > keyBuffer.size) {
+        keyBuffer = new Array[Byte](keySize)
+      }
+      in.readFully(keyBuffer, 0, keySize)
+      if (valuesSize > valuesBuffer.size) {
+        valuesBuffer = new Array[Byte](valuesSize)
+      }
+      in.readFully(valuesBuffer, 0, valuesSize)
+
+      // put it into binary map
+      val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
+      assert(!loc.isDefined, "Duplicated key found!")
+      loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
+        valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+      i += 1
+    }
   }
 }
 
@@ -195,7 +311,6 @@ private[joins] object UnsafeHashedRelation {
       rowSchema: StructType,
       sizeEstimate: Int): HashedRelation = {
 
-    // TODO: Use BytesToBytesMap.
     val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
     val toUnsafe = UnsafeProjection.create(rowSchema)
    val keyGenerator = UnsafeProjection.create(buildKeys)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 9dd2220f0967e..6e80604fca450 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.execution.joins
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.collection.CompactBuffer
 
 
@@ -80,8 +81,13 @@ class HashedRelationSuite extends SparkFunSuite {
     data2 += unsafeData(2).copy()
     assert(hashed.get(unsafeData(2)) === data2)
 
-    val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
-      .asInstanceOf[UnsafeHashedRelation]
+    val os = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(os)
+    hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
+    out.flush()
+    val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+    val hashed2 = new UnsafeHashedRelation()
+    hashed2.readExternal(in)
     assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
     assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
     assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
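
For readers tracing the new serialization format, here is a small standalone sketch (not part of the patch) of the per-key value block that writeExternal produces and get() parses: a repeated sequence of [num of fields][row size in bytes][row bytes], with the two ints stored in native byte order so UNSAFE.getInt() can read them directly. It decodes with a plain java.nio.ByteBuffer instead of PlatformDependent, and the names ValueBlockLayout and decode are illustrative only.

    import java.nio.{ByteBuffer, ByteOrder}

    object ValueBlockLayout {
      // Decode one value block into (numFields, sizeInBytes, rowBytes) triples.
      def decode(block: Array[Byte]): Seq[(Int, Int, Array[Byte])] = {
        // The patch writes the two length ints in native byte order, so read them the same way.
        val buf = ByteBuffer.wrap(block).order(ByteOrder.nativeOrder())
        val rows = Seq.newBuilder[(Int, Int, Array[Byte])]
        while (buf.remaining() > 0) {
          val numFields = buf.getInt()     // [num of fields]
          val sizeInBytes = buf.getInt()   // [num of bytes]
          val rowBytes = new Array[Byte](sizeInBytes)
          buf.get(rowBytes)                // [row bytes]
          rows += ((numFields, sizeInBytes, rowBytes))
        }
        rows.result()
      }
    }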