From b02b1d128fd8108a9af64e0cd7db65a7fcacca26 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Apr 2016 21:56:35 -0700 Subject: [PATCH 1/9] LongToUnsafeRowMap --- .../execution/joins/BroadcastHashJoin.scala | 6 +- .../spark/sql/execution/joins/HashJoin.scala | 11 +- .../sql/execution/joins/HashedRelation.scala | 476 +++++++++++------- .../execution/joins/ShuffledHashJoin.scala | 48 +- .../BenchmarkWholeStageCodegen.scala | 6 +- .../spark/sql/execution/ExchangeSuite.scala | 8 +- .../execution/joins/HashedRelationSuite.scala | 39 +- 7 files changed, 344 insertions(+), 250 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 67ac9e94ff2aa..982021fe3d5c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -51,10 +51,8 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - val mode = HashedRelationBroadcastMode( - canJoinKeyFitWithinLong, - rewriteKeyExpr(buildKeys), - buildPlan.output) + val key = rewriteKeyExpr(buildKeys).map(BindReferences.bindReference(_, buildPlan.output)) + val mode = HashedRelationBroadcastMode(key) buildSide match { case BuildLeft => BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index b7c0f3e7d13f1..b1f0ef8d3a640 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -84,17 +84,8 @@ trait HashJoin { width = dt.defaultSize } else { val 
bits = dt.defaultSize * 8 - // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same - // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys - // with two same ints have hash code 0, we rotate the bits of second one. - val rotated = if (e.dataType == IntegerType) { - // (e >>> 15) | (e << 17) - BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) - } else { - e - } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 5ccb435686f23..6cbfc2011cab8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.{HashMap => JavaHashMap} +import org.apache.spark.sql.types.DecimalType.Expression +import org.apache.spark.sql.types.LongType import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow @@ -109,16 +111,14 @@ private[execution] object HashedRelation { * Note: The caller should make sure that these InternalRow are different objects. 
*/ def apply( - canJoinKeyFitWithinLong: Boolean, - input: Iterator[InternalRow], - keyGenerator: Projection, - sizeEstimate: Int = 64): HashedRelation = { + input: Iterator[InternalRow], + key: Seq[Expression], + sizeEstimate: Int = 64): HashedRelation = { - if (canJoinKeyFitWithinLong) { - LongHashedRelation(input, keyGenerator, sizeEstimate) + if (key.length == 1 && key.head.dataType == LongType) { + LongHashedRelation(input, key, sizeEstimate) } else { - UnsafeHashedRelation( - input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + UnsafeHashedRelation(input, key, sizeEstimate) } } } @@ -276,7 +276,7 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - keyGenerator: UnsafeProjection, + key: Seq[Expression], sizeEstimate: Int): HashedRelation = { val taskMemoryManager = if (TaskContext.get() != null) { @@ -300,6 +300,7 @@ private[joins] object UnsafeHashedRelation { pageSizeBytes) // Create a mapping of buildKeys -> rows + val keyGenerator = UnsafeProjection.create(key) var numFields = 0 while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] @@ -321,144 +322,321 @@ private[joins] object UnsafeHashedRelation { } } +object LongToUnsafeRowMap { + // the largest prime that below 2^n + val LARGEST_PRIMES = { + // https://primes.utm.edu/lists/2small/0bit.html + val diffs = Seq( + 0, 1, 1, 3, 1, 3, 1, 5, + 3, 3, 9, 3, 1, 3, 19, 15, + 1, 5, 1, 3, 9, 3, 15, 3, + 39, 5, 39, 57, 3, 35, 1, 5 + ) + val primes = new Array[Int](32) + primes(0) = 1 + var power2 = 1 + (1 until 32).foreach { i => + power2 *= 2 + primes(i) = power2 - diffs(i) + } + primes + } + + val DENSE_FACTOR = 0.2 +} + /** - * An interface for a hashed relation that the key is a Long. 
+ * + * @param capacity */ -private[joins] trait LongHashedRelation extends HashedRelation { - override def get(key: InternalRow): Iterator[InternalRow] = { - get(key.getLong(0)) +final class LongToUnsafeRowMap(capacity: Int) extends Externalizable { + import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._ + // The actual capacity of map, is a prime number. + private var cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{ + sys.error(s"Can't create map with capacity $capacity") } - override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) + // The array to store the key and offset of UnsafeRow in the page + // [key1] [offset1 | size1] [key2] [offset | size2] ... + private var array = new Array[Long](cap * 2) + // The page to store all bytes of UnsafeRow + private var page = new Array[Byte](1 << 20) // 1M + // Current write cursor in the page + private var cursor = Platform.BYTE_ARRAY_OFFSET + private var numValues = 0 + private var numKeys = 0 + // Whether all the keys are unique or not. 
+ private var isUnique = true + + private var minKey = Long.MaxValue + private var maxKey = Long.MinValue + private var isDense = false + + def this() = this(0) // needed by serializer + + def keyIsUnique: Boolean = isUnique + + def getTotalMemoryConsumption: Long = { + array.length * 8 + page.length + } + + private def getSlot(key: Long): Int = { + var s = (key % cap).toInt * 2 + if (s < 0) { + s += cap * 2 + } + s } -} -private[joins] final class GeneralLongHashedRelation( - private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]]) - extends LongHashedRelation with Externalizable { + def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >= 0 && idx < array.length && array(idx) > 0) { + val pointer = array(idx) + val offset = pointer >>> 32 + val size = pointer & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + return resultRow + } + } else { + var pos = getSlot(key) + while (array(pos + 1) != 0) { + if (array(pos) == key) { + val pointer = array(pos + 1) + val offset = pointer >>> 32 + val size = pointer & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + return resultRow + } + pos += 2 + if (pos == array.length) { + pos = 0 + } + } + } + null + } - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) + def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >=0 && key <= maxKey && array(idx) > 0) { + return new Iterator[UnsafeRow] { + var pointer = array(idx) + override def hasNext: Boolean = pointer != 0 + override def next(): UnsafeRow = { + val offset = pointer >>> 32 + val size = pointer & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + pointer = Platform.getLong(page, offset + size) + resultRow + } + } + } + } else { + var pos = getSlot(key) + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return new Iterator[UnsafeRow] { 
+ var pointer = array(pos + 1) + override def hasNext: Boolean = pointer != 0 + override def next(): UnsafeRow = { + val offset = pointer >>> 32 + val size = pointer & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + pointer = Platform.getLong(page, offset + size) + resultRow + } + } + } + pos += 2 + if (pos == array.length) { + pos = 0 + } + } + } + null + } - override def keyIsUnique: Boolean = false + def append(key: Long, row: UnsafeRow): Unit = { + if (key < minKey) { + minKey = key + } + if (key > maxKey) { + maxKey = key + } - override def asReadOnlyCopy(): GeneralLongHashedRelation = - new GeneralLongHashedRelation(hashTable) + if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { + // TODO: memory manager + if (page.length > (1L << 31)) { + sys.error("Can't allocate a page that is larger than 2G") + } + val newPage = new Array[Byte](page.length * 2) + System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) + page = newPage + } + val offset = cursor + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) + cursor += row.getSizeInBytes + Platform.putLong(page, cursor, 0) + cursor += 8 + numValues += 1 + updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes) + } - override def get(key: Long): Iterator[InternalRow] = { - val rows = hashTable.get(key) - if (rows != null) { - rows.toIterator + private def updateIndex(key: Long, address: Long): Unit = { + var pos = getSlot(key) + while (array(pos + 1) != 0 && array(pos) != key) { + pos += 2 + if (pos == array.length) { + pos = 0 + } + } + if (array(pos + 1) == 0) { + array(pos) = key + array(pos + 1) = address + numKeys += 1 + if (numKeys * 2 > cap) { + grow() + } } else { - null + var addr = array(pos + 1) + var pointer = (addr >>> 32) + (addr & 0xffffffffL) + while (Platform.getLong(page, pointer) != 0) { + addr = Platform.getLong(page, pointer) + pointer = (addr >>> 32) + (addr & 0xffffffffL) + } + 
Platform.putLong(page, pointer, address) + isUnique = false + } + } + + private def grow(): Unit = { + val old_cap = cap + val old_array = array + cap = LARGEST_PRIMES.find(_ > cap).getOrElse{ + sys.error(s"Can't grow map any more than $cap") + } + numKeys = 0 + // println(s"grow from ${old_cap} to ${cap}") + array = new Array[Long](cap * 2) + var i = 0 + while (i < old_cap * 2) { + if (old_array(i + 1) > 0) { + updateIndex(old_array(i), old_array(i + 1)) + } + i += 2 + } + } + + def optimize(): Unit = { + if (numKeys > (maxKey - minKey) * DENSE_FACTOR) { + val denseArray = new Array[Long]((maxKey - minKey + 1).toInt) + var i = 0 + while (i < array.length) { + if (array(i + 1) > 0) { + val idx = (array(i) - minKey).toInt + denseArray(idx) = array(i + 1) + } + i += 2 + } + array = denseArray + isDense = true } } override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + out.writeBoolean(isUnique) + out.writeBoolean(isDense) + out.writeLong(minKey) + out.writeLong(maxKey) + out.writeInt(cap) + + out.writeInt(array.length) + val buffer = new Array[Byte](64 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) + out.write(buffer, 0, size) + offset += size + } + + val used = cursor - Platform.BYTE_ARRAY_OFFSET + out.writeInt(used) + out.write(page, 0, used) } override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + isUnique = in.readBoolean() + isDense = in.readBoolean() + minKey = in.readLong() + maxKey = in.readLong() + cap = in.readInt() + + val length = in.readInt() + array = new Array[Long](length) + val buffer = new Array[Byte](64 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset 
< end) { + val size = Math.min(buffer.length, end - offset) + in.readFully(buffer, 0, size) + Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) + offset += size + } + + val numBytes = in.readInt() + page = new Array[Byte](numBytes) + in.readFully(page) } } -/** - * A relation that pack all the rows into a byte array, together with offsets and sizes. - * - * All the bytes of UnsafeRow are packed together as `bytes`: - * - * [ Row0 ][ Row1 ][] ... [ RowN ] - * - * With keys: - * - * start start+1 ... start+N - * - * `offsets` are offsets of UnsafeRows in the `bytes` - * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. - * - * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: - * - * start = 3 - * offsets = [0, 0, 24] - * sizes = [24, 0, 32] - * bytes = [0 - 24][][24 - 56] - */ -private[joins] final class LongArrayRelation( - private var numFields: Int, - private var start: Long, - private var offsets: Array[Int], - private var sizes: Array[Int], - private var bytes: Array[Byte] - ) extends LongHashedRelation with Externalizable { +private[joins] class LongHashedRelation( + private var nFields: Int, + private var map: LongToUnsafeRowMap) + extends HashedRelation with Externalizable { - // Needed for serialization (it is public to make Java serialization work) - def this() = this(0, 0L, null, null, null) + private var resultRow: UnsafeRow = new UnsafeRow(nFields) - override def keyIsUnique: Boolean = true + // Needed for serialization (it is public to make Java serialization work) + def this() = this(0, null) - override def asReadOnlyCopy(): LongArrayRelation = { - new LongArrayRelation(numFields, start, offsets, sizes, bytes) - } + override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) override def getMemorySize: Long = { - offsets.length * 4 + sizes.length * 4 + bytes.length + map.getTotalMemoryConsumption } - override def get(key: Long): 
Iterator[InternalRow] = { - val row = getValue(key) - if (row != null) { - Seq(row).toIterator - } else { + override def get(key: InternalRow): Iterator[InternalRow] = { + if (key.isNullAt(0)) { null + } else { + get(key.getLong(0)) } } + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) + } + + override def get(key: Long): Iterator[InternalRow] = + map.get(key, resultRow) - var resultRow = new UnsafeRow(numFields) override def getValue(key: Long): InternalRow = { - val idx = (key - start).toInt - if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { - resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) - resultRow - } else { - null - } + map.getValue(key, resultRow) } + override def keyIsUnique: Boolean = map.keyIsUnique + override def writeExternal(out: ObjectOutput): Unit = { - out.writeInt(numFields) - out.writeLong(start) - out.writeInt(sizes.length) - var i = 0 - while (i < sizes.length) { - out.writeInt(sizes(i)) - i += 1 - } - out.writeInt(bytes.length) - out.write(bytes) + out.writeInt(nFields) + out.writeObject(map) } override def readExternal(in: ObjectInput): Unit = { - numFields = in.readInt() - resultRow = new UnsafeRow(numFields) - start = in.readLong() - val length = in.readInt() - // read sizes of rows - sizes = new Array[Int](length) - offsets = new Array[Int](length) - var i = 0 - var offset = 0 - while (i < length) { - offsets(i) = offset - sizes(i) = in.readInt() - offset += sizes(i) - i += 1 - } - // read all the bytes - val total = in.readInt() - assert(total == offset) - bytes = new Array[Byte](total) - in.readFully(bytes) + nFields = in.readInt() + resultRow = new UnsafeRow(nFields) + map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } } @@ -466,96 +644,44 @@ private[joins] final class LongArrayRelation( * Create hashed relation with key that is long. 
*/ private[joins] object LongHashedRelation { - - val DENSE_FACTOR = 0.2 - def apply( - input: Iterator[InternalRow], - keyGenerator: Projection, - sizeEstimate: Int): HashedRelation = { + input: Iterator[InternalRow], + key: Seq[Expression], + sizeEstimate: Int): LongHashedRelation = { - // TODO: use LongToBytesMap for better memory efficiency - val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate) + val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(sizeEstimate) + val keyGenerator = UnsafeProjection.create(key) // Create a mapping of key -> rows var numFields = 0 - var keyIsUnique = true - var minKey = Long.MaxValue - var maxKey = Long.MinValue while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] numFields = unsafeRow.numFields() val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val key = rowKey.getLong(0) - minKey = math.min(minKey, key) - maxKey = math.max(maxKey, key) - val existingMatchList = hashTable.get(key) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[UnsafeRow]() - hashTable.put(key, newMatchList) - newMatchList - } else { - keyIsUnique = false - existingMatchList - } - matchList += unsafeRow - } - } - - if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { - // The keys are dense enough, so use LongArrayRelation - val length = (maxKey - minKey).toInt + 1 - val sizes = new Array[Int](length) - val offsets = new Array[Int](length) - var offset = 0 - var i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - offsets(i) = offset - sizes(i) = rows(0).getSizeInBytes - offset += sizes(i) - } - i += 1 + map.append(key, unsafeRow) } - val bytes = new Array[Byte](offset) - i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) - } - i += 1 - } - new LongArrayRelation(numFields, minKey, offsets, 
sizes, bytes) - } else { - new GeneralLongHashedRelation(hashTable) } + map.optimize() + new LongHashedRelation(numFields, map) } } /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */ -private[execution] case class HashedRelationBroadcastMode( - canJoinKeyFitWithinLong: Boolean, - keys: Seq[Expression], - attributes: Seq[Attribute]) extends BroadcastMode { +private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) + extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - val generator = UnsafeProjection.create(keys, attributes) - HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length) + HashedRelation(rows.iterator, canonicalizedKey, rows.length) } - private lazy val canonicalizedKeys: Seq[Expression] = { - keys.map { e => - BindReferences.bindReference(e.canonicalized, attributes) - } + private lazy val canonicalizedKey: Seq[Expression] = { + key.map { e => e.canonicalized } } override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => - canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong && - canonicalizedKeys == m.canonicalizedKeys + case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index c63faacf33989..028640c61cfc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.memory.MemoryMode +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -58,46 +57,13 @@ case class ShuffledHashJoin( private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { val context = TaskContext.get() - if (!canJoinKeyFitWithinLong) { - // build BytesToBytesMap - val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator) - // This relation is usually used until the end of task. - context.addTaskCompletionListener((t: TaskContext) => - relation.close() - ) - return relation - } - - // try to acquire some memory for the hash table, it could trigger other operator to free some - // memory. The memory acquired here will mostly be used until the end of task. - val memoryManager = context.taskMemoryManager() - var acquired = 0L - var used = 0L + val key = rewriteKeyExpr(buildKeys).map(BindReferences.bindReference(_, buildPlan.output)) + val relation = HashedRelation(iter, key) + // This relation is usually used until the end of task. context.addTaskCompletionListener((t: TaskContext) => - memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null) + relation.close() ) - - val copiedIter = iter.map { row => - // It's hard to guess what's exactly memory will be used, we have a rough guess here. 
- // TODO: use LongToBytesMap instead of HashMap for memory efficiency - // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers - val needed = 150 + row.getSizeInBytes - if (needed > acquired - used) { - val got = memoryManager.acquireExecutionMemory( - Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null) - acquired += got - if (got < needed) { - throw new SparkException("Can't acquire enough memory to build hash map in shuffled" + - "hash join, please use sort merge join by setting " + - "spark.sql.join.preferSortMergeJoin=true") - } - } - used += needed - // HashedRelation requires that the UnsafeRow should be separate objects. - row.copy() - } - - HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator) + relation } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 3566ef304327a..8c5c37bd60f3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -46,7 +46,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { def runBenchmark(name: String, values: Long)(f: => Unit): Unit = { val benchmark = new Benchmark(name, values) - Seq(false, true).foreach { enabled => + Seq(true).foreach { enabled => benchmark.addCase(s"$name codegen=$enabled") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString) f @@ -165,7 +165,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } - ignore("broadcast hash join") { + test("broadcast hash join") { val N = 20 << 20 val M = 1 << 16 val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) @@ -310,7 +310,7 @@ class BenchmarkWholeStageCodegen extends 
SparkFunSuite { */ } - ignore("shuffle hash join") { + test("shuffle hash join") { val N = 4 << 20 sqlContext.setConf("spark.sql.shuffle.partitions", "2") sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 9680f3a008a59..17f2343cf971e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode - val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq()) - val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq()) + val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) + val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) @@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan) - val hashMode = HashedRelationBroadcastMode(true, output, plan.output) + val hashMode = HashedRelationBroadcastMode(output) val exchange2 = BroadcastExchange(hashMode, plan) val hashMode2 = - HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output) + HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) val exchange3 = BroadcastExchange(hashMode2, plan) val exchange4 = ReusedExchange(output, exchange3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index ed87a99439521..6004ed16196c7 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -37,8 +37,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { val unsafeData = data.map(toUnsafe(_).copy()) val buildKey = Seq(BoundReference(0, IntegerType, false)) - val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) @@ -100,31 +99,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } - test("LongArrayRelation") { + test("LongToUnsafeRowMap") { val unsafeProj = UnsafeProjection.create( Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) - val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false))) - val longRelation = LongHashedRelation(rows.iterator, keyProj, 100) - assert(longRelation.isInstanceOf[LongArrayRelation]) - val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation] + val key = Seq(BoundReference(0, IntegerType, false)) + val longRelation = LongHashedRelation(rows.iterator, key, 10) + assert(longRelation.keyIsUnique) (0 until 100).foreach { i => - val row = longArrayRelation.getValue(i) + val row = longRelation.getValue(i) assert(row.getInt(0) === i) assert(row.getInt(1) === i + 1) } + val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100) + assert(!longRelation2.keyIsUnique) + (0 until 100).foreach { i => + val rows = longRelation2.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) 
=== i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) + } + val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) - longArrayRelation.writeExternal(out) + longRelation2.writeExternal(out) out.flush() val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) - val relation = new LongArrayRelation() + val relation = new LongHashedRelation() relation.readExternal(in) + assert(!relation.keyIsUnique) (0 until 100).foreach { i => - val row = longArrayRelation.getValue(i) - assert(row.getInt(0) === i) - assert(row.getInt(1) === i + 1) + val rows = relation.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) } } } From 1b68736b72e0cc710a01ccbfd46e957b2fa5132f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 5 Apr 2016 17:08:47 -0700 Subject: [PATCH 2/9] update benchmark & cleanup --- .../aggregate/TungstenAggregate.scala | 3 +- .../sql/execution/joins/HashedRelation.scala | 292 +++++++++++------- .../execution/joins/ShuffledHashJoin.scala | 2 +- .../BenchmarkWholeStageCodegen.scala | 137 +++++--- .../execution/joins/HashedRelationSuite.scala | 15 +- 5 files changed, 305 insertions(+), 144 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 60027edc7c396..040d60f01d2f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -441,7 +441,7 @@ case class TungstenAggregate( val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, 
hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + ctx.addMutableState(hashMapClassName, hashMapTerm, s"") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") @@ -453,6 +453,7 @@ case class TungstenAggregate( ctx.addNewFunction(doAgg, s""" private void $doAgg() throws java.io.IOException { + $hashMapTerm = $thisPlan.createHashMap(); ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6cbfc2011cab8..bc18b291a0b6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -20,18 +20,16 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.{HashMap => JavaHashMap} -import org.apache.spark.sql.types.DecimalType.Expression -import org.apache.spark.sql.types.LongType -import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} -import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, Utils} -import 
org.apache.spark.util.collection.CompactBuffer /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete @@ -84,23 +82,7 @@ private[execution] sealed trait HashedRelation { /** * Release any used resources. */ - def close(): Unit = {} - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { - out.writeInt(serialized.length) // Write the length of serialized bytes first - out.write(serialized) - } - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def readBytes(in: ObjectInput): Array[Byte] = { - val serializedSize = in.readInt() // Read the length of serialized bytes first - val bytes = new Array[Byte](serializedSize) - in.readFully(bytes) - bytes - } + def close(): Unit } private[execution] object HashedRelation { @@ -111,14 +93,24 @@ private[execution] object HashedRelation { * Note: The caller should make sure that these InternalRow are different objects. 
*/ def apply( - input: Iterator[InternalRow], - key: Seq[Expression], - sizeEstimate: Int = 64): HashedRelation = { + input: Iterator[InternalRow], + key: Seq[Expression], + sizeEstimate: Int = 64, + taskMemoryManager: TaskMemoryManager = null): HashedRelation = { + val mm = Option(taskMemoryManager).getOrElse { + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + } if (key.length == 1 && key.head.dataType == LongType) { - LongHashedRelation(input, key, sizeEstimate) + LongHashedRelation(input, key, sizeEstimate, mm) } else { - UnsafeHashedRelation(input, key, sizeEstimate) + UnsafeHashedRelation(input, key, sizeEstimate, mm) } } } @@ -277,19 +269,9 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], key: Seq[Expression], - sizeEstimate: Int): HashedRelation = { + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): HashedRelation = { - val taskMemoryManager = if (TaskContext.get() != null) { - TaskContext.get().taskMemoryManager() - } else { - new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - } val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) @@ -346,117 +328,178 @@ object LongToUnsafeRowMap { } /** - * - * @param capacity + * An hash map mapping from key of Long to UnsafeRow. */ -final class LongToUnsafeRowMap(capacity: Int) extends Externalizable { +final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) + extends MemoryConsumer(mm) with Externalizable { import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._ - // The actual capacity of map, is a prime number. 
- private var cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{ - sys.error(s"Can't create map with capacity $capacity") - } - // The array to store the key and offset of UnsafeRow in the page - // [key1] [offset1 | size1] [key2] [offset | size2] ... - private var array = new Array[Long](cap * 2) - // The page to store all bytes of UnsafeRow - private var page = new Array[Byte](1 << 20) // 1M - // Current write cursor in the page + + // Whether the keys are stored in dense mode or not. + private var isDense = false + + // The minimum value of keys. + private var minKey = Long.MaxValue + + // The Maxinum value of keys. + private var maxKey = Long.MinValue + + // Sparse mode: the actual capacity of map, is a prime number. + private var cap: Int = 0 + + // The array to store the key and offset of UnsafeRow in the page. + // + // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... + // Dense mode: [offset1 | size1] [offset2 | size2] + private var array: Array[Long] = null + + // The page to store all bytes of UnsafeRow and the pointer to next rows. + // [row1][pointer1] [row2][pointer2] + private var page: Array[Byte] = null + + // Current write cursor in the page. private var cursor = Platform.BYTE_ARRAY_OFFSET + + // The total number of values of all keys. private var numValues = 0 + + // The number of unique keys. private var numKeys = 0 - // Whether all the keys are unique or not. 
- private var isUnique = true - private var minKey = Long.MaxValue - private var maxKey = Long.MinValue - private var isDense = false + def this() = this(null, 0) // needed by serializer + + private def acquireMemory(size: Long): Unit = { + // do not support spilling + val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) + if (got < size) { + mm.releaseExecutionMemory(got, MemoryMode.ON_HEAP, this) + throw new SparkException(s"Can't acquire $size bytes memory to build hash relation") + } + } - def this() = this(0) // needed by serializer + private def init(): Unit = { + if (mm != null) { + cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{ + sys.error(s"Can't create map with capacity $capacity") + } + acquireMemory(cap * 2 * 8 + (1 << 20)) + array = new Array[Long](cap * 2) + page = new Array[Byte](1 << 20) // 1M bytes + } + } - def keyIsUnique: Boolean = isUnique + init() + def spill(size: Long, trigger: MemoryConsumer): Long = { + 0L + } + + /** + * Returns whether all the keys are unique. + * @return + */ + def keyIsUnique: Boolean = numKeys == numValues + + /** + * Returns total memory consumption. + */ def getTotalMemoryConsumption: Long = { array.length * 8 + page.length } + /** + * Returns the slot of array that store the keys (sparse mode). + */ private def getSlot(key: Long): Int = { - var s = (key % cap).toInt * 2 - if (s < 0) { - s += cap * 2 + var s = (key % cap).toInt + while (s < 0) { + s += cap } - s + s * 2 } + /** + * Returns the single UnsafeRow for given key, or null if not found. 
+ */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { val idx = (key - minKey).toInt - if (idx >= 0 && idx < array.length && array(idx) > 0) { - val pointer = array(idx) - val offset = pointer >>> 32 - val size = pointer & 0xffffffffL + if (idx >= 0 && key <= maxKey && array(idx) > 0) { + val addr = array(idx) + val offset = addr >>> 32 + val size = addr & 0xffffffffL resultRow.pointTo(page, offset, size.toInt) return resultRow } } else { var pos = getSlot(key) + var step = 1 while (array(pos + 1) != 0) { if (array(pos) == key) { - val pointer = array(pos + 1) - val offset = pointer >>> 32 - val size = pointer & 0xffffffffL + val addr = array(pos + 1) + val offset = addr >>> 32 + val size = addr & 0xffffffffL resultRow.pointTo(page, offset, size.toInt) return resultRow } - pos += 2 - if (pos == array.length) { - pos = 0 + pos += 2 * step + step += 1 + if (pos >= cap) { + pos -= cap } } } null } + /** + * Returns an iterator for all the values for the given key, or null if no value found. 
+ */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { val idx = (key - minKey).toInt if (idx >=0 && key <= maxKey && array(idx) > 0) { return new Iterator[UnsafeRow] { - var pointer = array(idx) - override def hasNext: Boolean = pointer != 0 + var addr = array(idx) + override def hasNext: Boolean = addr != 0 override def next(): UnsafeRow = { - val offset = pointer >>> 32 - val size = pointer & 0xffffffffL + val offset = addr >>> 32 + val size = addr & 0xffffffffL resultRow.pointTo(page, offset, size.toInt) - pointer = Platform.getLong(page, offset + size) + addr = Platform.getLong(page, offset + size) resultRow } } } } else { var pos = getSlot(key) + var step = 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return new Iterator[UnsafeRow] { - var pointer = array(pos + 1) - override def hasNext: Boolean = pointer != 0 + var addr = array(pos + 1) + override def hasNext: Boolean = addr != 0 override def next(): UnsafeRow = { - val offset = pointer >>> 32 - val size = pointer & 0xffffffffL + val offset = addr >>> 32 + val size = addr & 0xffffffffL resultRow.pointTo(page, offset, size.toInt) - pointer = Platform.getLong(page, offset + size) + addr = Platform.getLong(page, offset + size) resultRow } } } - pos += 2 - if (pos == array.length) { - pos = 0 + pos += 2 * step + step += 1 + if (pos >= cap) { + pos -= cap } } } null } + /** + * Appends the key and row into this map. 
+ */ def append(key: Long, row: UnsafeRow): Unit = { if (key < minKey) { minKey = key @@ -466,13 +509,15 @@ final class LongToUnsafeRowMap(capacity: Int) extends Externalizable { } if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { - // TODO: memory manager - if (page.length > (1L << 31)) { + val used = page.length + if (used * 2L > (1L << 31)) { sys.error("Can't allocate a page that is larger than 2G") } - val newPage = new Array[Byte](page.length * 2) + acquireMemory(used * 2) + val newPage = new Array[Byte](used * 2) System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) page = newPage + mm.releaseExecutionMemory(used, MemoryMode.ON_HEAP, this) } val offset = cursor Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) @@ -483,22 +528,30 @@ final class LongToUnsafeRowMap(capacity: Int) extends Externalizable { updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes) } + /** + * Update the address in array for given key. + */ private def updateIndex(key: Long, address: Long): Unit = { var pos = getSlot(key) + var step = 1 while (array(pos + 1) != 0 && array(pos) != key) { - pos += 2 - if (pos == array.length) { - pos = 0 + pos += 2 * step + step += 1 + if (pos >= cap) { + pos -= cap } } if (array(pos + 1) == 0) { + // this is the first value for this key, put the address in array. array(pos) = key array(pos + 1) = address numKeys += 1 if (numKeys * 2 > cap) { - grow() + // reach half of the capacity + growArray() } } else { + // there is another value for this key, put the address at the end of final value. 
var addr = array(pos + 1) var pointer = (addr >>> 32) + (addr & 0xffffffffL) while (Platform.getLong(page, pointer) != 0) { @@ -506,18 +559,17 @@ final class LongToUnsafeRowMap(capacity: Int) extends Externalizable { pointer = (addr >>> 32) + (addr & 0xffffffffL) } Platform.putLong(page, pointer, address) - isUnique = false } } - private def grow(): Unit = { + private def growArray(): Unit = { val old_cap = cap - val old_array = array + var old_array = array cap = LARGEST_PRIMES.find(_ > cap).getOrElse{ sys.error(s"Can't grow map any more than $cap") } numKeys = 0 - // println(s"grow from ${old_cap} to ${cap}") + acquireMemory(cap * 2 * 8) array = new Array[Long](cap * 2) var i = 0 while (i < old_cap * 2) { @@ -526,11 +578,25 @@ final class LongToUnsafeRowMap(capacity: Int) extends Externalizable { } i += 2 } + old_array = null // release the reference to old array + mm.releaseExecutionMemory(old_cap * 2 * 8, MemoryMode.ON_HEAP, this) } + /** + * Try to turn the map into dense mode, which is faster to probe. 
+ */ def optimize(): Unit = { - if (numKeys > (maxKey - minKey) * DENSE_FACTOR) { - val denseArray = new Array[Long]((maxKey - minKey + 1).toInt) + val range = maxKey - minKey + // Convert to dense mode if it does not require more memory or could fit within L1 cache + if (range < cap * 2 || range < 1024) { + try { + acquireMemory((range + 1) * 8) + } catch { + case e: SparkException => + // there is no enough memory to convert + return + } + val denseArray = new Array[Long]((range + 1).toInt) var i = 0 while (i < array.length) { if (array(i + 1) > 0) { @@ -541,14 +607,26 @@ final class LongToUnsafeRowMap(capacity: Int) extends Externalizable { } array = denseArray isDense = true + mm.releaseExecutionMemory(cap * 2 * 8, MemoryMode.ON_HEAP, this) + } + } + + def free(): Unit = { + if (page != null) { + mm.releaseExecutionMemory(page.length, MemoryMode.ON_HEAP, this) + page = null + } + if (array != null) { + mm.releaseExecutionMemory(array.length * 8, MemoryMode.ON_HEAP, this) } } override def writeExternal(out: ObjectOutput): Unit = { - out.writeBoolean(isUnique) out.writeBoolean(isDense) out.writeLong(minKey) out.writeLong(maxKey) + out.writeInt(numKeys) + out.writeInt(numValues) out.writeInt(cap) out.writeInt(array.length) @@ -568,10 +646,11 @@ final class LongToUnsafeRowMap(capacity: Int) extends Externalizable { } override def readExternal(in: ObjectInput): Unit = { - isUnique = in.readBoolean() isDense = in.readBoolean() minKey = in.readLong() maxKey = in.readLong() + numKeys = in.readInt() + numValues = in.readInt() cap = in.readInt() val length = in.readInt() @@ -628,6 +707,10 @@ private[joins] class LongHashedRelation( override def keyIsUnique: Boolean = map.keyIsUnique + override def close(): Unit = { + map.free() + } + override def writeExternal(out: ObjectOutput): Unit = { out.writeInt(nFields) out.writeObject(map) @@ -647,9 +730,10 @@ private[joins] object LongHashedRelation { def apply( input: Iterator[InternalRow], key: Seq[Expression], - sizeEstimate: 
Int): LongHashedRelation = { + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): LongHashedRelation = { - val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(sizeEstimate) + val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) val keyGenerator = UnsafeProjection.create(key) // Create a mapping of key -> rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 028640c61cfc2..5a7af5f6cef63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -58,7 +58,7 @@ case class ShuffledHashJoin( private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { val context = TaskContext.get() val key = rewriteKeyExpr(buildKeys).map(BindReferences.bindReference(_, buildPlan.output)) - val relation = HashedRelation(iter, key) + val relation = HashedRelation(iter, key, taskMemoryManager = context.taskMemoryManager()) // This relation is usually used until the end of task. 
context.addTaskCompletionListener((t: TaskContext) => relation.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 8c5c37bd60f3b..27f8e62757b57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.util.HashMap +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext @@ -46,7 +47,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { def runBenchmark(name: String, values: Long)(f: => Unit): Unit = { val benchmark = new Benchmark(name, values) - Seq(true).foreach { enabled => + Seq(false, true).foreach { enabled => benchmark.addCase(s"$name codegen=$enabled") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString) f @@ -165,7 +166,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } - test("broadcast hash join") { + ignore("broadcast hash join") { val N = 20 << 20 val M = 1 << 16 val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) @@ -179,8 +180,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X - Join w long codegen=true 275 / 352 76.2 13.1 19.4X + Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X + Join w long codegen=true 321 / 371 65.3 15.3 9.3X */ runBenchmark("Join w long duplicated", 
N) { @@ -193,8 +194,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X - Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X + Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X + Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X */ val dim2 = broadcast(sqlContext.range(M) @@ -211,8 +212,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X - Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X + Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X + Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X */ val dim3 = broadcast(sqlContext.range(M) @@ -259,8 +260,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X - outer join w long codegen=true 216 / 226 97.2 10.3 26.3X + outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X + outer join w long codegen=true 261 / 276 80.5 12.4 11.7X */ runBenchmark("semi join w long", N) { @@ -272,8 +273,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - semi join w long 
codegen=false 4690 / 4953 4.5 223.7 1.0X - semi join w long codegen=true 211 / 229 99.2 10.1 22.2X + semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X + semi join w long codegen=true 237 / 244 88.3 11.3 8.1X */ } @@ -310,7 +311,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } - test("shuffle hash join") { + ignore("shuffle hash join") { val N = 4 << 20 sqlContext.setConf("spark.sql.shuffle.partitions", "2") sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000") @@ -326,8 +327,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X - shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X + shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X + shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X */ } @@ -349,11 +350,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("hash and BytesToBytesMap") { - val N = 10 << 20 + val N = 20 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) - benchmark.addCase("hash") { iter => + benchmark.addCase("UnsafeRowhash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) @@ -368,15 +369,34 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + benchmark.addCase("murmur3 hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = Murmur3_x86_32.hashLong(i, 42) + key.setInt(0, h) + s += h + i += 1 + } + } + benchmark.addCase("fast hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 var s = 0 while (i < 
N) { - key.setInt(0, i % 1000) - val h = Murmur3_x86_32.hashLong(i % 1000, 42) + var h = i % p + if (h < 0) { + h += p + } + key.setInt(0, h) s += h i += 1 } @@ -475,6 +495,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + Seq(false, true).foreach { optimized => + benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new LongToUnsafeRowMap(taskMemoryManager, 64) + while (i < 65536) { + value.setInt(0, i) + val key = i % 100000 + map.append(key, value) + i += 1 + } + if (optimized) { + map.optimize() + } + var s = 0 + i = 0 + while (i < N) { + val key = i % 100000 + if (map.getValue(key, value) != null) { + s += 1 + } + i += 1 + } + } + } + Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -493,18 +549,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - while (i < N) { + val numKeys = 65536 + while (i < numKeys) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, Murmur3_x86_32.hashLong(i % 65536, 42)) - if (loc.isDefined) { - value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - value.setInt(0, value.getInt(0) + 1) - i += 1 - } else { + if (!loc.isDefined) { loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) } + i += 1 + } + i = 0 + var s = 0 + while (i < N) { + key.setInt(0, i % 100000) + val loc = 
map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 100000, 42)) + if (loc.isDefined) { + s += 1 + } + i += 1 } } } @@ -536,16 +601,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - hash 112 / 116 93.2 10.7 1.0X - fast hash 65 / 69 160.9 6.2 1.7X - arrayEqual 66 / 69 159.1 6.3 1.7X - Java HashMap (Long) 137 / 182 76.3 13.1 0.8X - Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X - Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X - BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X - BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X - Aggregate HashMap 56 / 62 187.9 5.3 2.0X - */ + UnsafeRow hash 267 / 284 78.4 12.8 1.0X + murmur3 hash 102 / 129 205.5 4.9 2.6X + fast hash 79 / 96 263.8 3.8 3.4X + arrayEqual 164 / 172 128.2 7.8 1.6X + Java HashMap (Long) 321 / 399 65.4 15.3 0.8X + Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X + Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X + LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X + LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X + BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X + BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X + Aggregate HashMap 121 / 131 173.3 5.8 2.2X */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 6004ed16196c7..371a9ed617d65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,14 +30,23 @@ import org.apache.spark.util.collection.CompactBuffer class HashedRelationSuite extends SparkFunSuite with 
SharedSQLContext { + val mm = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()) + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) @@ -104,7 +113,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) val key = Seq(BoundReference(0, IntegerType, false)) - val longRelation = LongHashedRelation(rows.iterator, key, 10) + val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) assert(longRelation.keyIsUnique) (0 until 100).foreach { i => val row = longRelation.getValue(i) @@ -112,7 +121,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(row.getInt(1) === i + 1) } - val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100) + val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm) assert(!longRelation2.keyIsUnique) (0 until 100).foreach { i => val rows = longRelation2.get(i).toArray From 87e32f368f775a37f9b9ad3f7c18f29ee05db12f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 5 Apr 2016 17:22:57 -0700 Subject: [PATCH 3/9] fix style --- .../org/apache/spark/sql/execution/joins/HashedRelation.scala | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index bc18b291a0b6b..fb561d4f72a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.{HashMap => JavaHashMap} -import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow From a3b45d89a04e18513ee1aa288c182632506eb761 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 8 Apr 2016 22:32:45 -0700 Subject: [PATCH 4/9] clenaup & add comments --- .../execution/joins/BroadcastHashJoin.scala | 16 +- .../spark/sql/execution/joins/HashJoin.scala | 20 +- .../sql/execution/joins/HashedRelation.scala | 186 +++++++++++------- .../execution/joins/ShuffledHashJoin.scala | 7 +- 4 files changed, 131 insertions(+), 98 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 982021fe3d5c6..2d269c4dde193 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import 
org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.LongType import org.apache.spark.util.collection.CompactBuffer /** @@ -51,8 +52,7 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - val key = rewriteKeyExpr(buildKeys).map(BindReferences.bindReference(_, buildPlan.output)) - val mode = HashedRelationBroadcastMode(key) + val mode = HashedRelationBroadcastMode(buildKeys) buildSide match { case BuildLeft => BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -67,7 +67,7 @@ case class BroadcastHashJoin( val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) join(streamedIter, hashed, numOutputRows) } } @@ -103,7 +103,7 @@ case class BroadcastHashJoin( ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.getMemorySize()); + | incPeakExecutionMemory($relationTerm.estimatedSize()); """.stripMargin) (broadcastRelation, relationTerm) } @@ -116,15 +116,13 @@ case class BroadcastHashJoin( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - if (canJoinKeyFitWithinLong) { + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { // generate the join key as Long - val expr = rewriteKeyExpr(streamedKeys).head - val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) + val ev = streamedKeys.head.gen(ctx) (ev, ev.isNull) } else { // generate the join key as UnsafeRow - val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) - val 
ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) (ev, s"${ev.value}.anyNull()") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index b1f0ef8d3a640..8f5b9e2a279d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -59,9 +59,13 @@ trait HashJoin { case BuildRight => (right, left) } - protected lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) + protected lazy val (buildKeys, streamedKeys) = { + val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + buildSide match { + case BuildLeft => (lkeys, rkeys) + case BuildRight => (rkeys, lkeys) + } } /** @@ -96,17 +100,11 @@ trait HashJoin { keyExpr :: Nil } - protected lazy val canJoinKeyFitWithinLong: Boolean = { - val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType) - val key = rewriteKeyExpr(buildKeys) - sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] - } - protected def buildSideKeyGenerator(): Projection = - UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) + UnsafeProjection.create(buildKeys) protected def streamSideKeyGenerator(): UnsafeProjection = - UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) + UnsafeProjection.create(streamedKeys) @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index fb561d4f72a55..4959f60dab275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -33,7 +33,7 @@ import org.apache.spark.util.{KnownSizeEstimation, Utils} * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[execution] sealed trait HashedRelation { +private[execution] sealed trait HashedRelation extends KnownSizeEstimation { /** * Returns matched rows. * @@ -72,11 +72,6 @@ private[execution] sealed trait HashedRelation { */ def asReadOnlyCopy(): HashedRelation - /** - * Returns the size of used memory. - */ - def getMemorySize: Long = 1L // to make the test happy - /** * Release any used resources. */ @@ -87,8 +82,6 @@ private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. - * - * Note: The caller should make sure that these InternalRow are different objects. 
*/ def apply( input: Iterator[InternalRow], @@ -123,7 +116,7 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numFields: Int, private var binaryMap: BytesToBytesMap) - extends HashedRelation with KnownSizeEstimation with Externalizable { + extends HashedRelation with Externalizable { private[joins] def this() = this(0, null) // Needed for serialization @@ -132,10 +125,6 @@ private[joins] class UnsafeHashedRelation( override def asReadOnlyCopy(): UnsafeHashedRelation = new UnsafeHashedRelation(numFields, binaryMap) - override def getMemorySize: Long = { - binaryMap.getTotalMemoryConsumption - } - override def estimatedSize: Long = { binaryMap.getTotalMemoryConsumption } @@ -302,7 +291,7 @@ private[joins] object UnsafeHashedRelation { } } -object LongToUnsafeRowMap { +private[joins] object LongToUnsafeRowMap { // the largest prime that below 2^n val LARGEST_PRIMES = { // https://primes.utm.edu/lists/2small/0bit.html @@ -321,14 +310,41 @@ private[joins] object LongToUnsafeRowMap { } primes } - - val DENSE_FACTOR = 0.2 } /** - * An hash map mapping from key of Long to UnsafeRow. + * An append-only hash map mapping from key of Long to UnsafeRow. + * + * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array + * (`page`) in this format: + * + * [bytes of row1][address1][bytes of row2][address2] ... + * + * address1 (8 bytes) is the offset and size of next value for the same key as row1, any key + * could have multiple values. The address at the end of last value for every key is 0. + * + * The keys and addresses of their values could be stored in two modes: + * + * 1) sparse mode: the keys and addresses are stored in `array` as: + * + * [key1][address1][key2][address2]...[] + * + * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1 + * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to address + * hash collision.
+ * + * 2) dense mode: all the addresses are packed into a single array of long, as: + * + * [address1] [address2] ... + * + * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is + * determined by `key1 - minKey`. + * + * The map is created as sparse mode, then key-value could be appended into it. Once finish + * appending, caller could all optimize() to try to turn the map into dense mode, which is faster + * to probe. */ -final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) +private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) extends MemoryConsumer(mm) with Externalizable { import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._ @@ -363,7 +379,18 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) // The number of unique keys. private var numKeys = 0 - def this() = this(null, 0) // needed by serializer + // needed by serializer + def this() = { + this( + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0), + 0) + } private def acquireMemory(size: Long): Unit = { // do not support spilling @@ -374,6 +401,10 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) } } + private def freeMemory(size: Long): Unit = { + mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this) + } + private def init(): Unit = { if (mm != null) { cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{ @@ -393,7 +424,6 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) /** * Returns whether all the keys are unique. 
- * @return */ def keyIsUnique: Boolean = numKeys == numValues @@ -409,12 +439,19 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) */ private def getSlot(key: Long): Int = { var s = (key % cap).toInt - while (s < 0) { + if (s < 0) { s += cap } s * 2 } + private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { + val offset = address >>> 32 + val size = address & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + resultRow + } + /** * Returns the single UnsafeRow for given key, or null if not found. */ @@ -422,33 +459,42 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) if (isDense) { val idx = (key - minKey).toInt if (idx >= 0 && key <= maxKey && array(idx) > 0) { - val addr = array(idx) - val offset = addr >>> 32 - val size = addr & 0xffffffffL - resultRow.pointTo(page, offset, size.toInt) - return resultRow + return getRow(array(idx), resultRow) } } else { var pos = getSlot(key) var step = 1 while (array(pos + 1) != 0) { if (array(pos) == key) { - val addr = array(pos + 1) - val offset = addr >>> 32 - val size = addr & 0xffffffffL - resultRow.pointTo(page, offset, size.toInt) - return resultRow + return getRow(array(pos + 1), resultRow) } pos += 2 * step step += 1 - if (pos >= cap) { - pos -= cap + if (pos >= array.length) { + pos -= array.length } } } null } + /** + * Returns an iterator of UnsafeRow for multiple linked values. + */ + private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + new Iterator[UnsafeRow] { + var addr = address + override def hasNext: Boolean = addr != 0 + override def next(): UnsafeRow = { + val offset = addr >>> 32 + val size = addr & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + addr = Platform.getLong(page, offset + size) + resultRow + } + } + } + /** * Returns an iterator for all the values for the given key, or null if no value found.
*/ @@ -456,39 +502,19 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) if (isDense) { val idx = (key - minKey).toInt if (idx >=0 && key <= maxKey && array(idx) > 0) { - return new Iterator[UnsafeRow] { - var addr = array(idx) - override def hasNext: Boolean = addr != 0 - override def next(): UnsafeRow = { - val offset = addr >>> 32 - val size = addr & 0xffffffffL - resultRow.pointTo(page, offset, size.toInt) - addr = Platform.getLong(page, offset + size) - resultRow - } - } + return valueIter(array(idx), resultRow) } } else { var pos = getSlot(key) var step = 1 while (array(pos + 1) != 0) { if (array(pos) == key) { - return new Iterator[UnsafeRow] { - var addr = array(pos + 1) - override def hasNext: Boolean = addr != 0 - override def next(): UnsafeRow = { - val offset = addr >>> 32 - val size = addr & 0xffffffffL - resultRow.pointTo(page, offset, size.toInt) - addr = Platform.getLong(page, offset + size) - resultRow - } - } + return valueIter(array(pos + 1), resultRow) } pos += 2 * step step += 1 - if (pos >= cap) { - pos -= cap + if (pos >= array.length) { + pos -= array.length } } } @@ -506,6 +532,7 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) maxKey = key } + // There is 8 bytes for the pointer to next value if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { val used = page.length if (used * 2L > (1L << 31)) { @@ -515,8 +542,10 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) val newPage = new Array[Byte](used * 2) System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) page = newPage - mm.releaseExecutionMemory(used, MemoryMode.ON_HEAP, this) + freeMemory(used) } + + // copy the bytes of UnsafeRow val offset = cursor Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) cursor += row.getSizeInBytes @@ -535,8 +564,8 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) 
while (array(pos + 1) != 0 && array(pos) != key) { pos += 2 * step step += 1 - if (pos >= cap) { - pos -= cap + if (pos >= array.length) { + pos -= array.length } } if (array(pos + 1) == 0) { @@ -570,14 +599,14 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) acquireMemory(cap * 2 * 8) array = new Array[Long](cap * 2) var i = 0 - while (i < old_cap * 2) { + while (i < old_array.length) { if (old_array(i + 1) > 0) { updateIndex(old_array(i), old_array(i + 1)) } i += 2 } old_array = null // release the reference to old array - mm.releaseExecutionMemory(old_cap * 2 * 8, MemoryMode.ON_HEAP, this) + freeMemory(old_cap * 2 * 8) } /** @@ -586,7 +615,7 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) def optimize(): Unit = { val range = maxKey - minKey // Convert to dense mode if it does not require more memory or could fit within L1 cache - if (range < cap * 2 || range < 1024) { + if (range < array.length || range < 1024) { try { acquireMemory((range + 1) * 8) } catch { @@ -603,19 +632,24 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) } i += 2 } + val old_length = array.length array = denseArray isDense = true - mm.releaseExecutionMemory(cap * 2 * 8, MemoryMode.ON_HEAP, this) + freeMemory(old_length * 8) } } + /** + * Free all the memory acquired by this map. 
+ */ def free(): Unit = { if (page != null) { - mm.releaseExecutionMemory(page.length, MemoryMode.ON_HEAP, this) + freeMemory(page.length) page = null } if (array != null) { - mm.releaseExecutionMemory(array.length * 8, MemoryMode.ON_HEAP, this) + freeMemory(array.length * 8) + array = null } } @@ -628,7 +662,7 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) out.writeInt(cap) out.writeInt(array.length) - val buffer = new Array[Byte](64 << 10) + val buffer = new Array[Byte](4 << 10) var offset = Platform.LONG_ARRAY_OFFSET val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET while (offset < end) { @@ -653,7 +687,7 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) val length = in.readInt() array = new Array[Long](length) - val buffer = new Array[Byte](64 << 10) + val buffer = new Array[Byte](4 << 10) var offset = Platform.LONG_ARRAY_OFFSET val end = length * 8 + Platform.LONG_ARRAY_OFFSET while (offset < end) { @@ -670,9 +704,8 @@ final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) } private[joins] class LongHashedRelation( - private var nFields: Int, - private var map: LongToUnsafeRowMap) - extends HashedRelation with Externalizable { + private var nFields: Int, + private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable { private var resultRow: UnsafeRow = new UnsafeRow(nFields) @@ -681,7 +714,7 @@ private[joins] class LongHashedRelation( override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) - override def getMemorySize: Long = { + override def estimatedSize: Long = { map.getTotalMemoryConsumption } @@ -692,8 +725,13 @@ private[joins] class LongHashedRelation( get(key.getLong(0)) } } + override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) + if (key.isNullAt(0)) { + null + } else { + getValue(key.getLong(0)) + } } override def get(key: Long): Iterator[InternalRow] = @@ -740,7 +778,7 @@ private[joins] 
object LongHashedRelation { val unsafeRow = input.next().asInstanceOf[UnsafeRow] numFields = unsafeRow.numFields() val rowKey = keyGenerator(unsafeRow) - if (!rowKey.anyNull) { + if (!rowKey.isNullAt(0)) { val key = rowKey.getLong(0) map.append(key, unsafeRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5a7af5f6cef63..c267dc82de838 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -55,10 +55,9 @@ case class ShuffledHashJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { + private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val context = TaskContext.get() - val key = rewriteKeyExpr(buildKeys).map(BindReferences.bindReference(_, buildPlan.output)) - val relation = HashedRelation(iter, key, taskMemoryManager = context.taskMemoryManager()) + val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) // This relation is usually used until the end of task. 
context.addTaskCompletionListener((t: TaskContext) => relation.close() @@ -69,7 +68,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => - val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) + val hashed = buildHashedRelation(buildIter) join(streamIter, hashed, numOutputRows) } } From 028eec539960b92a0b0188a8ba3d8f149230e5a5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 8 Apr 2016 22:35:52 -0700 Subject: [PATCH 5/9] fix style --- .../spark/sql/execution/BenchmarkWholeStageCodegen.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 59cfc2de0a0f8..5caedd7dfb29b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -611,7 +611,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X - Aggregate HashMap 121 / 131 173.3 5.8 2.2X */ + Aggregate HashMap 121 / 131 173.3 5.8 2.2X + */ benchmark.run() } From 32c2165e455fc1e1dd864b60a3f1b0b16991d50b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 8 Apr 2016 22:53:50 -0700 Subject: [PATCH 6/9] fix style --- .../apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala 
index 5caedd7dfb29b..352fd07d0e8b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution import java.util.HashMap -import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.AggregateHashMap From 87660e24d2f3638e7cdacdc8ae106581ffa0c409 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 9 Apr 2016 15:01:44 -0700 Subject: [PATCH 7/9] change array length to power of 2 --- .../sql/execution/joins/HashedRelation.scala | 106 ++++++------------ 1 file changed, 35 insertions(+), 71 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 4959f60dab275..7a6bb8247b35a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -291,27 +291,6 @@ private[joins] object UnsafeHashedRelation { } } -private[joins] object LongToUnsafeRowMap { - // the largest prime that below 2^n - val LARGEST_PRIMES = { - // https://primes.utm.edu/lists/2small/0bit.html - val diffs = Seq( - 0, 1, 1, 3, 1, 3, 1, 5, - 3, 3, 9, 3, 1, 3, 19, 15, - 1, 5, 1, 3, 9, 3, 15, 3, - 39, 5, 39, 57, 3, 35, 1, 5 - ) - val primes = new Array[Int](32) - primes(0) = 1 - var power2 = 1 - (1 until 32).foreach { i => - power2 *= 2 - primes(i) = power2 - diffs(i) - } - primes - } -} - /** * An append-only hash map 
mapping from key of Long to UnsafeRow. * @@ -343,28 +322,27 @@ private[joins] object LongToUnsafeRowMap { * The map is created as sparse mode, then key-value could be appended into it. Once finish * appending, caller could all optimize() to try to turn the map into dense mode, which is faster * to probe. + * + * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/ */ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) extends MemoryConsumer(mm) with Externalizable { - import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._ // Whether the keys are stored in dense mode or not. private var isDense = false - // The minimum value of keys. + // The minimum key private var minKey = Long.MaxValue - // The Maxinum value of keys. + // The maximum key private var maxKey = Long.MinValue - // Sparse mode: the actual capacity of map, is a prime number. - private var cap: Int = 0 - // The array to store the key and offset of UnsafeRow in the page. // // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... // Dense mode: [offset1 | size1] [offset2 | size2] private var array: Array[Long] = null + private var mask: Int = 0 // The page to store all bytes of UnsafeRow and the pointer to next rows.
// [row1][pointer1] [row2][pointer2] @@ -407,11 +385,11 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap private def init(): Unit = { if (mm != null) { - cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{ - sys.error(s"Can't create map with capacity $capacity") - } - acquireMemory(cap * 2 * 8 + (1 << 20)) - array = new Array[Long](cap * 2) + var n = 1 + while (n < capacity) n *= 2 + acquireMemory(n * 2 * 8 + (1 << 20)) + array = new Array[Long](n * 2) + mask = n * 2 - 2 page = new Array[Byte](1 << 20) // 1M bytes } } @@ -435,14 +413,18 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap } /** - * Returns the slot of array that store the keys (sparse mode). + * Returns the first slot of array that store the keys (sparse mode). */ - private def getSlot(key: Long): Int = { - var s = (key % cap).toInt - if (s < 0) { - s += cap - } - s * 2 + private def firstSlot(key: Long): Int = { + val h = key * 0x9E3779B9L + (h ^ (h >> 32)).toInt & mask + } + + /** + * Returns the next probe in the array. 
+ */ + private def nextSlot(pos: Int): Int = { + (pos + 2) & mask } private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { @@ -462,17 +444,12 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap return getRow(array(idx), resultRow) } } else { - var pos = getSlot(key) - var step = 1 + var pos = firstSlot(key) while (array(pos + 1) != 0) { if (array(pos) == key) { return getRow(array(pos + 1), resultRow) } - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } + pos = nextSlot(pos) } } null @@ -505,17 +482,12 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap return valueIter(array(idx), resultRow) } } else { - var pos = getSlot(key) - var step = 1 + var pos = firstSlot(key) while (array(pos + 1) != 0) { if (array(pos) == key) { return valueIter(array(pos + 1), resultRow) } - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } + pos = nextSlot(pos) } } null @@ -559,21 +531,16 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap * Update the address in array for given key. */ private def updateIndex(key: Long, address: Long): Unit = { - var pos = getSlot(key) - var step = 1 - while (array(pos + 1) != 0 && array(pos) != key) { - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } + var pos = firstSlot(key) + while (array(pos) != key && array(pos + 1) != 0) { + pos = nextSlot(pos) } if (array(pos + 1) == 0) { // this is the first value for this key, put the address in array. 
array(pos) = key array(pos + 1) = address numKeys += 1 - if (numKeys * 2 > cap) { + if (numKeys * 4 > array.length) { // reach half of the capacity growArray() } @@ -590,14 +557,12 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap } private def growArray(): Unit = { - val old_cap = cap var old_array = array - cap = LARGEST_PRIMES.find(_ > cap).getOrElse{ - sys.error(s"Can't grow map any more than $cap") - } + val n = array.length numKeys = 0 - acquireMemory(cap * 2 * 8) - array = new Array[Long](cap * 2) + acquireMemory(n * 2 * 8) + array = new Array[Long](n * 2) + mask = n * 2 - 2 var i = 0 while (i < old_array.length) { if (old_array(i + 1) > 0) { @@ -606,7 +571,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap i += 2 } old_array = null // release the reference to old array - freeMemory(old_cap * 2 * 8) + freeMemory(n * 8) } /** @@ -659,7 +624,6 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap out.writeLong(maxKey) out.writeInt(numKeys) out.writeInt(numValues) - out.writeInt(cap) out.writeInt(array.length) val buffer = new Array[Byte](4 << 10) @@ -683,10 +647,10 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap maxKey = in.readLong() numKeys = in.readInt() numValues = in.readInt() - cap = in.readInt() val length = in.readInt() array = new Array[Long](length) + mask = length - 2 val buffer = new Array[Byte](4 << 10) var offset = Platform.LONG_ARRAY_OFFSET val end = length * 8 + Platform.LONG_ARRAY_OFFSET From 9e1a110bae674f93de89660ab42e4071c8f46fc8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 9 Apr 2016 13:26:46 -0700 Subject: [PATCH 8/9] improve building hashed relaton with long key --- .../apache/spark/sql/execution/joins/HashJoin.scala | 10 ++++------ .../spark/sql/execution/joins/HashedRelation.scala | 12 ++++-------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 4c912d371e05e..d6feedc27244b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException - -import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.types.{IntegralType, LongType} trait HashJoin { self: SparkPlan => @@ -60,6 +56,8 @@ trait HashJoin { } protected lazy val (buildKeys, streamedKeys) = { + require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), + "Join keys from two sides should have same types") val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) buildSide match { @@ -73,7 +71,7 @@ trait HashJoin { * * If not, returns the original expressions. 
*/ - def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { var keyExpr: Expression = null var width = 0 keys.foreach { e => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 7a6bb8247b35a..68b5486faaf50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -545,14 +545,10 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap growArray() } } else { - // there is another value for this key, put the address at the end of final value. - var addr = array(pos + 1) - var pointer = (addr >>> 32) + (addr & 0xffffffffL) - while (Platform.getLong(page, pointer) != 0) { - addr = Platform.getLong(page, pointer) - pointer = (addr >>> 32) + (addr & 0xffffffffL) - } - Platform.putLong(page, pointer, address) + // there are some values for this key, put the address in the front of them. 
+ val pointer = (address >>> 32) + (address & 0xffffffffL) + Platform.putLong(page, pointer, array(pos + 1)) + array(pos + 1) = address } } From 8774ebc9244a9edc35c86971d239f2110cb7e13e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 9 Apr 2016 13:49:17 -0700 Subject: [PATCH 9/9] insert new value in the beginning of linked list in BytesToBytesMap --- .../apache/spark/unsafe/map/BytesToBytesMap.java | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 32958be7a7fd7..6807710f9fef1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -716,7 +716,8 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff offset += klen; Platform.copyMemory(vbase, voff, base, offset, vlen); offset += vlen; - Platform.putLong(base, offset, 0); + // put this value at the beginning of the list + Platform.putLong(base, offset, isDefined ? 
longArray.get(pos * 2) : 0); // --- Update bookkeeping data structures ---------------------------------------------------- offset = currentPage.getBaseOffset(); @@ -724,17 +725,12 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff pageCursor += recordLength; final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( currentPage, recordOffset); + longArray.set(pos * 2, storedKeyAddress); + updateAddressesAndSizes(storedKeyAddress); numValues++; - if (isDefined) { - // put this pair at the end of chain - while (nextValue()) { /* do nothing */ } - Platform.putLong(baseObject, valueOffset + valueLength, storedKeyAddress); - nextValue(); // point to new added value - } else { + if (!isDefined) { numKeys++; - longArray.set(pos * 2, storedKeyAddress); longArray.set(pos * 2 + 1, keyHashcode); - updateAddressesAndSizes(storedKeyAddress); isDefined = true; if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) {