From 21825529eae66293ec5d8638911303fa54944dd5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 28 Jul 2015 15:56:19 -0700 Subject: [PATCH] [SPARK-9247] [SQL] Use BytesToBytesMap for broadcast join This PR introduce BytesToBytesMap to UnsafeHashedRelation, use it in executor for better performance. It serialize all the key and values from java HashMap, put them into a BytesToBytesMap while deserializing. All the values for a same key are stored continuous to have better memory locality. This PR also address the comments for #7480 , do some clean up. Author: Davies Liu Closes #7592 from davies/unsafe_map2 and squashes the following commits: 42c578a [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_map2 fd09528 [Davies Liu] remove thread local cache and update docs 1c5ad8d [Davies Liu] fix test 5eb1b5a [Davies Liu] address comments in #7480 46f1f22 [Davies Liu] fix style fc221e0 [Davies Liu] use BytesToBytesMap for broadcast join --- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../joins/BroadcastHashOuterJoin.scala | 2 +- .../joins/BroadcastLeftSemiJoinHash.scala | 6 +- .../joins/BroadcastNestedLoopJoin.scala | 36 ++-- .../spark/sql/execution/joins/HashJoin.scala | 35 ++-- .../sql/execution/joins/HashOuterJoin.scala | 34 ++-- .../sql/execution/joins/HashSemiJoin.scala | 14 +- .../sql/execution/joins/HashedRelation.scala | 166 ++++++++++++++---- .../execution/joins/LeftSemiJoinHash.scala | 2 +- .../execution/joins/ShuffledHashJoin.scala | 2 +- .../joins/ShuffledHashOuterJoin.scala | 8 +- .../execution/joins/HashedRelationSuite.scala | 28 +-- 12 files changed, 214 insertions(+), 121 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index abaa4a6ce86a2..624efc1b1d734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -62,7 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = buildHashRelation(input.iterator) + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index c9d1a880f4ef4..77e7fe71009b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = buildHashRelation(input.iterator) + val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index f71c0ce352904..a60593911f94f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash( condition: Option[Expression]) extends BinaryNode with HashSemiJoin { protected override def doExecute(): RDD[InternalRow] = { - val buildIter = right.execute().map(_.copy()).collect().toIterator + val input = right.execute().map(_.copy()).collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) + val hashSet = buildKeyHashSet(input.toIterator) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashRelation = buildHashRelation(buildIter) + val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 700636966f8be..83b726a8e2897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -47,13 +47,11 @@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: Projection = { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin( var streamRowMatched = false while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => @@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin( val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - buf += resultProjection(new JoinedRow(leftNulls, rel(i))) - case (LeftOuter | FullOuter, BuildLeft) => - buf += resultProjection(new JoinedRow(rel(i), rightNulls)) - case _ => + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => + val joinedRow = new JoinedRow + joinedRow.withLeft(leftNulls) + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + buf += resultProjection(joinedRow.withRight(rel(i))).copy() + } + i += 1 } - } - i += 1 + case (LeftOuter | FullOuter, BuildLeft) => + val joinedRow = new JoinedRow + joinedRow.withRight(rightNulls) + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + buf += resultProjection(joinedRow.withLeft(rel(i))).copy() + } + i += 1 + } + case _ => } buf.toSeq } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 46ab5b0d1cc6d..6b3d1652923fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.util.collection.CompactBuffer trait HashJoin { @@ -44,16 +43,24 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - protected[this] def supportUnsafe: Boolean = { + protected[this] def isUnsafeMode: Boolean = { (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + + @transient protected lazy val buildSideKeyGenerator: Projection = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildPlan.output) + } else { + newMutableProjection(buildKeys, buildPlan.output)() + } @transient protected lazy val streamSideKeyGenerator: Projection = - if (supportUnsafe) { + if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newMutableProjection(streamedKeys, streamedPlan.output)() @@ -65,18 +72,16 @@ trait HashJoin { { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: CompactBuffer[InternalRow] = _ + private[this] var currentHashMatches: Seq[InternalRow] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: Projection = { - if (supportUnsafe) { + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -122,12 +127,4 @@ trait HashJoin { } } } - - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, buildKeys, buildPlan) - } else { - HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 6bf2f82954046..7e671e7914f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -75,30 +75,36 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - protected[this] def supportUnsafe: Boolean = { + protected[this] def isUnsafeMode: Boolean = { (self.codegenEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode - protected[this] def streamedKeyGenerator(): Projection = { - if (supportUnsafe) { + @transient protected lazy val buildKeyGenerator: Projection = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildPlan.output) + } else { + newMutableProjection(buildKeys, buildPlan.output)() + } + + @transient protected[this] lazy val streamedKeyGenerator: Projection = { + if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newProjection(streamedKeys, streamedPlan.output) } } - @transient private[this] lazy val resultProjection: Projection = { - if (supportUnsafe) { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { + if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -230,12 +236,4 @@ trait HashOuterJoin { hashTable } - - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, buildKeys, buildPlan) - } else { - HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 7f49264d40354..97fde8f975bfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -35,11 +35,13 @@ trait HashSemiJoin { protected[this] def supportUnsafe: Boolean = { (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(left.schema)) + && UnsafeProjection.canSupport(left.schema) + && UnsafeProjection.canSupport(right.schema)) } - override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def outputsUnsafeRows: Boolean = supportUnsafe override def canProcessUnsafeRows: Boolean = supportUnsafe + override def canProcessSafeRows: Boolean = !supportUnsafe @transient protected lazy val leftKeyGenerator: Projection = if (supportUnsafe) { @@ -87,14 +89,6 @@ trait HashSemiJoin { }) } - protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, rightKeys, right) - } else { - HashedRelation(buildIter, newProjection(rightKeys, right.output)) - } - } - protected def hashSemiJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 8d5731afd59b8..9c058f1f72fe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.collection.CompactBuffer @@ -32,7 +35,7 @@ import org.apache.spark.util.collection.CompactBuffer * object. */ private[joins] sealed trait HashedRelation { - def get(key: InternalRow): CompactBuffer[InternalRow] + def get(key: InternalRow): Seq[InternalRow] // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation @@ -59,9 +62,9 @@ private[joins] final class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private def this() = this(null) // Needed for serialization - override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key) + override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -81,9 +84,9 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private def this() = this(null) // Needed for serialization - override def get(key: InternalRow): CompactBuffer[InternalRow] = { + override def get(key: InternalRow): Seq[InternalRow] = { val v = hashTable.get(key) if (v eq null) null else CompactBuffer(v) } @@ -109,6 +112,10 @@ private[joins] object HashedRelation { keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { + if (keyGenerator.isInstanceOf[UnsafeProjection]) { + return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + } + // TODO: Use Spark's HashMap implementation. val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate) var currentRow: InternalRow = null @@ -149,31 +156,133 @@ private[joins] object HashedRelation { } } - /** - * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a - * sequence of values. + * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key + * into a sequence of values. + * + * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use + * BytesToBytesMap for better memory performance (multiple values for the same are stored as a + * continuous byte array. * - * TODO(davies): use BytesToBytesMap + * It's serialized in the following format: + * [number of keys] + * [size of key] [size of all values in bytes] [key bytes] [bytes for all values] + * ... + * + * All the values are serialized as following: + * [number of fields] [number of bytes] [underlying bytes of UnsafeRow] + * ... */ private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private[joins] def this() = this(null) // Needed for serialization + + // Use BytesToBytesMap in executor for better performance (it's created when deserialization) + @transient private[this] var binaryMap: BytesToBytesMap = _ - override def get(key: InternalRow): CompactBuffer[InternalRow] = { + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] - // Thanks to type eraser - hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + + if (binaryMap != null) { + // Used in Broadcast join + val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes) + if (loc.isDefined) { + val buffer = CompactBuffer[UnsafeRow]() + + val base = loc.getValueAddress.getBaseObject + var offset = loc.getValueAddress.getBaseOffset + val last = loc.getValueAddress.getBaseOffset + loc.getValueLength + while (offset < last) { + val numFields = PlatformDependent.UNSAFE.getInt(base, offset) + val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) + offset += 8 + + val row = new UnsafeRow + row.pointTo(base, offset, numFields, sizeInBytes) + buffer += row + offset += sizeInBytes + } + buffer + } else { + null + } + + } else { + // Use the JavaHashMap in Local mode or ShuffleHashJoin + hashTable.get(unsafeKey) + } } override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + out.writeInt(hashTable.size()) + + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + val key = entry.getKey + val values = entry.getValue + + // write all the values as single byte array + var totalSize = 0L + var i = 0 + while (i < values.size) { + totalSize += values(i).getSizeInBytes + 4 + 4 + i += 1 + } + assert(totalSize < Integer.MAX_VALUE, "values are too big") + + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(key.getSizeInBytes) + out.writeInt(totalSize.toInt) + out.write(key.getBytes) + i = 0 + while (i < values.size) { + // [num of fields] [num of bytes] [row bytes] + // write the integer in native order, so they can be read by UNSAFE.getInt() + if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { + out.writeInt(values(i).numFields()) + out.writeInt(values(i).getSizeInBytes) + } else { + out.writeInt(Integer.reverseBytes(values(i).numFields())) + out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + } + out.write(values(i).getBytes) + i += 1 + } + } } override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + val nKeys = in.readInt() + // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory + val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2) // reduce hash collision + + var i = 0 + var keyBuffer = new Array[Byte](1024) + var valuesBuffer = new Array[Byte](1024) + while (i < nKeys) { + val keySize = in.readInt() + val valuesSize = in.readInt() + if (keySize > keyBuffer.size) { + keyBuffer = new Array[Byte](keySize) + } + in.readFully(keyBuffer, 0, keySize) + if (valuesSize > valuesBuffer.size) { + valuesBuffer = new Array[Byte](valuesSize) + } + in.readFully(valuesBuffer, 0, valuesSize) + + // put it into binary map + val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + assert(!loc.isDefined, "Duplicated key found!") + loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + i += 1 + } } } @@ -181,33 +290,14 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - buildKeys: Seq[Expression], - buildPlan: SparkPlan, - sizeEstimate: Int = 64): HashedRelation = { - val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) - apply(input, boundedKeys, buildPlan.schema, sizeEstimate) - } - - // Used for tests - def apply( - input: Iterator[InternalRow], - buildKeys: Seq[Expression], - rowSchema: StructType, + keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { - // TODO: Use BytesToBytesMap. val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) - val toUnsafe = UnsafeProjection.create(rowSchema) - val keyGenerator = UnsafeProjection.create(buildKeys) // Create a mapping of buildKeys -> rows while (input.hasNext) { - val currentRow = input.next() - val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { - currentRow.asInstanceOf[UnsafeRow] - } else { - toUnsafe(currentRow) - } + val unsafeRow = input.next().asInstanceOf[UnsafeRow] val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 874712a4e739f..26a664104d6fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -46,7 +46,7 @@ case class LeftSemiJoinHash( val hashSet = buildKeyHashSet(buildIter) hashSemiJoin(streamIter, hashSet) } else { - val hashRelation = buildHashRelation(buildIter) + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) hashSemiJoin(streamIter, hashRelation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 948d0ccebceb0..5439e10a60b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,7 +45,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = buildHashRelation(buildIter) + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index f54f1edd38ec8..d29b593207c4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,8 +50,8 @@ case class ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val hashed = buildHashRelation(rightIter) - val keyGenerator = streamedKeyGenerator() + val hashed = HashedRelation(rightIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) @@ -59,8 +59,8 @@ case class ShuffledHashOuterJoin( }) case RightOuter => - val hashed = buildHashRelation(leftIter) - val keyGenerator = streamedKeyGenerator() + val hashed = HashedRelation(leftIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9dd2220f0967e..8b1a9b21a96b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.joins +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.types.{StructField, StructType, IntegerType} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -64,27 +65,34 @@ class HashedRelationSuite extends SparkFunSuite { } test("UnsafeHashedRelation") { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val toUnsafe = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafe(_).copy()).toArray + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val schema = StructType(StructField("a", IntegerType, true) :: Nil) - val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1) + val keyGenerator = UnsafeProjection.create(buildKey) + val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - val toUnsafeKey = UnsafeProjection.create(schema) - val unsafeData = data.map(toUnsafeKey(_).copy()).toArray assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) - assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed.get(toUnsafe(InternalRow(10))) === null) val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) data2 += unsafeData(2).copy() assert(hashed.get(unsafeData(2)) === data2) - val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) - .asInstanceOf[UnsafeHashedRelation] + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) - assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) } }