From bea4a509b7f6097474dbee17675bd6f13da2db15 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Jul 2015 13:22:21 -0700 Subject: [PATCH 01/14] Unsafe HashJoin --- .../sql/catalyst/expressions/UnsafeRow.java | 32 ++++++- .../sql/catalyst/expressions/Projection.scala | 8 +- .../expressions/UnsafeRowConverter.scala | 4 +- .../execution/joins/BroadcastHashJoin.scala | 11 ++- .../sql/execution/joins/HashedRelation.scala | 94 ++++++++++++++++++- .../execution/joins/ShuffledHashJoin.scala | 12 ++- .../execution/joins/HashedRelationSuite.scala | 48 +++++++--- .../spark/unsafe/bitset/BitSetMethods.java | 2 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 10 +- 9 files changed, 194 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 87294a0e21441..14fba3e655748 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.UTF8String; @@ -345,7 +346,7 @@ public double getDouble(int i) { * This method is only supported on UnsafeRows that do not use ObjectPools. */ @Override - public InternalRow copy() { + public UnsafeRow copy() { if (pool != null) { throw new UnsupportedOperationException( "Copy is not supported for UnsafeRows that use object pools"); @@ -365,6 +366,33 @@ public InternalRow copy() { } } + @Override + public int hashCode() { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UnsafeRow) { + UnsafeRow o = (UnsafeRow) other; + return ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, + sizeInBytes); + } + return false; + } + + // This is for debugging + @Override + public String toString(){ + StringBuilder build = new StringBuilder("["); + for (int i = 0; i < sizeInBytes; i += 8) { + build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)); + build.append(','); + } + build.append(']'); + return build.toString(); + } + @Override public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 24b01ea55110e..43578b52c0026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -85,8 +85,12 @@ abstract class UnsafeProjection extends Projection { object UnsafeProjection { def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) - def create(fields: Seq[DataType]): UnsafeProjection = { + def create(fields: Array[DataType]): UnsafeProjection = { val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) + 
create(exprs) + } + + def create(exprs: Seq[Expression]): UnsafeProjection = { GenerateUnsafeProjection.generate(exprs) } } @@ -96,6 +100,8 @@ object UnsafeProjection { */ case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { + def this(schema: StructType) = this(schema.fields.map(_.dataType)) + private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => new BoundReference(idx, dt, true) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 885ab091fcdf5..702deb04acb67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -111,7 +111,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { /** * Function for writing a column into an UnsafeRow. */ -private abstract class UnsafeColumnWriter { +abstract class UnsafeColumnWriter { /** * Write a value into an UnsafeRow. * @@ -130,7 +130,7 @@ private abstract class UnsafeColumnWriter { def getSize(source: InternalRow, column: Int): Int } -private object UnsafeColumnWriter { +object UnsafeColumnWriter { def forType(dataType: DataType): UnsafeColumnWriter = { dataType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 7ffdce60d2955..03a51afa6f555 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -23,7 +23,7 @@ import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils @@ -62,7 +62,14 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + val hashed = if (left.codegenEnabled && + buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { + UnsafeHashedRelation(input.iterator, + buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), + buildPlan.schema) + } else { + HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + } sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6b51f5d4151d3..0e1089bf4b643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,12 +17,13 @@ package 
org.apache.spark.sql.execution.joins -import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection +import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow, Projection} import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR } } - // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. @@ -148,3 +148,91 @@ private[joins] object HashedRelation { } } } + + +/** + * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a + * sequence of values. + * + * TODO(davies): use BytesToBytesMap + */ +private[joins] final class UnsafeHashedRelation( + private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]], + private var keyTypes: Array[DataType]) + extends HashedRelation with Externalizable { + + def this() = this(null, null) // Needed for serialization + + // UnsafeProjection is not thread safe + @transient lazy val keyProjection = new ThreadLocal[UnsafeProjection] + + override def get(key: InternalRow): CompactBuffer[InternalRow] = { + val unsafeKey = if (key.isInstanceOf[UnsafeRow]) { + key.asInstanceOf[UnsafeRow] + } else { + var proj = keyProjection.get() + if (proj eq null) { + proj = UnsafeProjection.create(keyTypes) + keyProjection.set(proj) + } + proj(key) + } + // reply on type erasure in Scala + hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + } + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(keyTypes)) + val bytes = SparkSqlSerializer.serialize(hashTable) + println(s"before write ${hashTable}") + println(s"write bytes ${bytes.toString}") + writeBytes(out, bytes) + } + + override def readExternal(in: ObjectInput): Unit = { + keyTypes = SparkSqlSerializer.deserialize(readBytes(in)) + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + println(s"loaded ${hashTable}") + } +} + +private[joins] object UnsafeHashedRelation { + + def apply( + input: Iterator[InternalRow], + buildKey: Seq[Expression], + rowSchema: StructType, + sizeEstimate: Int = 64): HashedRelation = { + + // TODO: Use BytesToBytesMap. 
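(Illustrative aside, not part of the patch series.) The reason a plain JavaHashMap can be keyed on UnsafeRow at all is the byte-wise equals()/hashCode() added to UnsafeRow in this patch: two rows produced by the same UnsafeProjection from equal input values carry identical bytes, so they hash to the same bucket and compare equal. A minimal sketch of that property, assuming only the UnsafeProjection.create(Array[DataType]) factory and the UnsafeRow-returning copy() introduced above; the schema and values are made up for the example:

import java.util.{HashMap => JavaHashMap}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, IntegerType}
import org.apache.spark.util.collection.CompactBuffer

val toUnsafe = UnsafeProjection.create(Array[DataType](IntegerType))
val map = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]()
// copy() each key: the projection reuses its backing buffer between calls
val k1 = toUnsafe(InternalRow(42)).copy()
val k2 = toUnsafe(InternalRow(42)).copy()
map.put(k1, new CompactBuffer[UnsafeRow]())
assert(map.containsKey(k2))  // same bytes => same hashCode, equals => same entry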
+ val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) + var currentRow: InternalRow = null + val rowProj = UnsafeProjection.create(rowSchema) + val keyGenerator = UnsafeProjection.create(buildKey) + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + currentRow = input.next() + val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { + currentRow.asInstanceOf[UnsafeRow] + } else { + rowProj(currentRow) + } + val rowKey = keyGenerator(unsafeRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[UnsafeRow]() + hashTable.put(rowKey.copy(), newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += unsafeRow.copy() + } + } + + val keySchema = buildKey.map(_.dataType).toArray + new UnsafeHashedRelation(hashTable, keySchema) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5439e10a60b2a..8f5b77bdfe217 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -44,8 +44,16 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val codegenEnabled = left.codegenEnabled buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + val hashed = if (codegenEnabled && + buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { + UnsafeHashedRelation(buildIter, + buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), + buildPlan.schema) + } else { + HashedRelation(buildIter, buildSideKeyGenerator) + } hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9d9858b1c6151..28194df7be906 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.{StructField, StructType, IntegerType} import org.apache.spark.util.collection.CompactBuffer @@ -35,13 +37,13 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) 
assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) assert(hashed.get(InternalRow(10)) === null) val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) - assert(hashed.get(data(2)) == data2) + assert(hashed.get(data(2)) === data2) } test("UniqueKeyHashedRelation") { @@ -49,15 +51,39 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2))) assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) == data(0)) - assert(uniqHashed.getValue(data(1)) == data(1)) - assert(uniqHashed.getValue(data(2)) == data(2)) - assert(uniqHashed.getValue(InternalRow(10)) == null) + assert(uniqHashed.getValue(data(0)) === data(0)) + assert(uniqHashed.getValue(data(1)) === data(1)) + assert(uniqHashed.getValue(data(2)) === data(2)) + assert(uniqHashed.getValue(InternalRow(10)) === null) + } + + test("UnsafeHashedRelation") { + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val buildKey = Seq(BoundReference(0, IntegerType, false)) + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema) + assert(hashed.isInstanceOf[UnsafeHashedRelation]) + + val toUnsafe = UnsafeProjection.create(schema) + assert(hashed.get(data(0)) === CompactBuffer[UnsafeRow](toUnsafe(data(0)))) + assert(hashed.get(data(1)) === CompactBuffer[UnsafeRow](toUnsafe(data(1)))) + assert(hashed.get(InternalRow(10)) === null) + + val data2 = CompactBuffer[InternalRow](toUnsafe(data(2)).copy()) + data2 += toUnsafe(data(2)).copy() + assert(hashed.get(data(2)) === data2) + + val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) + .asInstanceOf[UnsafeHashedRelation] + assert(hashed2.get(data(0)) === CompactBuffer[UnsafeRow](toUnsafe(data(0)))) + assert(hashed2.get(data(1)) === CompactBuffer[UnsafeRow](toUnsafe(data(1)))) + assert(hashed2.get(InternalRow(10)) === null) + assert(hashed2.get(data(2)) === data2) } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index 27462c7fa5e62..a936a9b81e16e 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -72,7 +72,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { */ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; - for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { + for (int i = 0; i < bitSetWidthInWords; i += 8 * WORD_SIZE, addr += WORD_SIZE) { if 
(PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) { return true; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 85cd02469adb7..61f483ced3217 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -44,12 +44,16 @@ public int hashInt(int input) { return fmix(h1, 4); } - public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) { + public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = seed; - for (int offset = 0; offset < lengthInBytes; offset += 4) { - int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } From 95d0762b1e024f539235dbde17efa983bfbafed7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Jul 2015 13:46:53 -0700 Subject: [PATCH 02/14] remove println --- .../org/apache/spark/sql/execution/joins/HashedRelation.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0e1089bf4b643..43be2dddcc603 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -184,15 +184,12 @@ private[joins] final class UnsafeHashedRelation( override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(keyTypes)) val bytes = SparkSqlSerializer.serialize(hashTable) - println(s"before write ${hashTable}") - println(s"write bytes ${bytes.toString}") writeBytes(out, bytes) } override def readExternal(in: ObjectInput): Unit = { keyTypes = SparkSqlSerializer.deserialize(readBytes(in)) hashTable = SparkSqlSerializer.deserialize(readBytes(in)) - println(s"loaded ${hashTable}") } } From 6acbb11a0c751e8d76bd97404148d09da1b8933c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 20 Jul 2015 12:21:39 -0700 Subject: [PATCH 03/14] fix tests --- .../sql/catalyst/expressions/UnsafeRow.java | 17 +++++- .../catalyst/expressions/BoundAttribute.scala | 19 ++++++- .../sql/catalyst/expressions/Projection.scala | 4 ++ .../expressions/StringFunctionsSuite.scala | 2 +- .../execution/joins/BroadcastHashJoin.scala | 11 +--- .../joins/BroadcastHashOuterJoin.scala | 31 +++-------- .../spark/sql/execution/joins/HashJoin.scala | 10 ++++ .../sql/execution/joins/HashOuterJoin.scala | 55 +++++++++++++++++-- .../sql/execution/joins/HashedRelation.scala | 37 ++++++++++--- .../execution/joins/ShuffledHashJoin.scala | 11 +--- .../joins/ShuffledHashOuterJoin.scala | 13 +++-- .../spark/unsafe/bitset/BitSetMethods.java | 2 +- 12 files changed, 146 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 14fba3e655748..4e0cb35ac12af 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -381,6 +381,21 @@ public boolean equals(Object other) { return false; } + /** + * Returns the underline bytes for this UnsafeRow. + */ + public byte[] getBytes() { + if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + && (((byte[]) baseObject).length == sizeInBytes)) { + return (byte[]) baseObject; + } else { + byte[] bytes = new byte[sizeInBytes]; + PlatformDependent.copyMemory(baseObject, baseOffset, bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + return bytes; + } + } + // This is for debugging @Override public String toString(){ @@ -395,6 +410,6 @@ public String toString(){ @Override public boolean anyNull() { - return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); + return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index b09aea03318da..639973c68eac5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ /** @@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal]" - override def eval(input: InternalRow): Any = input(ordinal) + // Use special getter for primitive types (for UnsafeRow) + override def eval(input: InternalRow): Any = { + if (input.isNullAt(ordinal)) { + null + } else { + dataType match { + case BooleanType => input.getBoolean(ordinal) + case ByteType => input.getByte(ordinal) + case ShortType => input.getShort(ordinal) + case IntegerType | DateType => input.getInt(ordinal) + case LongType | TimestampType => input.getLong(ordinal) + case FloatType => input.getFloat(ordinal) + case DoubleType => input.getDouble(ordinal) + case _ => input.get(ordinal) + } + } + } override def name: String = s"i[$ordinal]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 43578b52c0026..b2e0d3b0b187a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -93,6 +93,10 @@ object UnsafeProjection { def create(exprs: Seq[Expression]): UnsafeProjection = { GenerateUnsafeProjection.generate(exprs) } + + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { + create(exprs.map(BindReferences.bindReference(_, inputSchema))) + } } /** diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 5d7763bedf6bd..d3dde7191d6cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -302,7 +302,7 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("FORMAT") { val f = 'f.string.at(0) val d1 = 'd.int.at(1) - val s1 = 's.int.at(2) + val s1 = 's.string.at(2) val row1 = create_row("aa%d%s", 12, "cc") val row2 = create_row(null, 12, "cc") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 03a51afa6f555..abaa4a6ce86a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -23,7 +23,7 @@ import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils @@ -62,14 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = if (left.codegenEnabled && - buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { - UnsafeHashedRelation(input.iterator, - buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), - buildPlan.schema) - } else { - HashedRelation(input.iterator, buildSideKeyGenerator, input.length) - } + val hashed = buildHashRelation(input.iterator) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index ab757fc7de6cd..b9326166338c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.joins +import scala.concurrent._ +import scala.concurrent.duration._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils -import scala.collection.JavaConversions._ -import scala.concurrent._ -import scala.concurrent.duration._ - /** * :: DeveloperApi :: * Performs a outer hash join for two child relations. 
When the output RDD of this operator is @@ -58,28 +57,12 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - private[this] lazy val (buildPlan, streamedPlan) = joinType match { - case RightOuter => (left, right) - case LeftOuter => (right, left) - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - - private[this] lazy val (buildKeys, streamedKeys) = joinType match { - case RightOuter => (leftKeys, rightKeys) - case LeftOuter => (rightKeys, leftKeys) - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - // buildHashTable uses code-generated rows as keys, which are not serializable - val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output)) + val hashed = buildHashRelation(input.iterator) + //val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output)) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) @@ -96,14 +79,14 @@ case class BroadcastHashOuterJoin( streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey)) }) case RightOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ff85ea3f6a410..79ff5b6a2b629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -103,4 +103,14 @@ trait HashJoin { } } } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (self.codegenEnabled && buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { + UnsafeHashedRelation(buildIter, + buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), + buildPlan.schema) + } else { + HashedRelation(buildIter, buildSideKeyGenerator) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 0522ee85eeb8a..091d512190493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -38,7 +38,7 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan -override def outputPartitioning: Partitioning = joinType match { + override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -59,6 +59,31 
@@ override def outputPartitioning: Partitioning = joinType match { } } + protected[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + protected[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + protected[this] def streamedKeyGenerator(): Projection = { + if (self.codegenEnabled && + streamedKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newProjection(streamedKeys, streamedPlan.output) + } + } + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -77,8 +102,12 @@ override def outputPartitioning: Partitioning = joinType match { rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() + } + } else { + List() } if (temp.isEmpty) { joinedRow.withRight(rightNullRow).copy :: Nil @@ -98,9 +127,13 @@ override def outputPartitioning: Partitioning = joinType match { joinedRow: JoinedRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy() + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => + joinedRow.copy() + } + } else { + List() } if (temp.isEmpty) { joinedRow.withLeft(leftNullRow).copy :: Nil @@ -179,4 +212,14 @@ override def outputPartitioning: Partitioning = joinType match { hashTable } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (self.codegenEnabled && buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { + UnsafeHashedRelation(buildIter, + buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), + buildPlan.schema) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 43be2dddcc603..b0bf7ad6d745f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow, Projection} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -158,13 +158,15 @@ private[joins] object 
HashedRelation { */ private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]], - private var keyTypes: Array[DataType]) + private var keyTypes: Array[DataType], + private var rowTypes: Array[DataType]) extends HashedRelation with Externalizable { - def this() = this(null, null) // Needed for serialization + def this() = this(null, null, null) // Needed for serialization // UnsafeProjection is not thread safe @transient lazy val keyProjection = new ThreadLocal[UnsafeProjection] + @transient lazy val fromUnsafeProjection = new ThreadLocal[FromUnsafeProjection] override def get(key: InternalRow): CompactBuffer[InternalRow] = { val unsafeKey = if (key.isInstanceOf[UnsafeRow]) { @@ -177,18 +179,38 @@ private[joins] final class UnsafeHashedRelation( } proj(key) } - // reply on type erasure in Scala - hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + + val values = hashTable.get(unsafeKey) + // Return GenericInternalRow to work with other JoinRow, which + // TODO(davies): return UnsafeRow once we have UnsafeJoinRow. + if (values != null) { + var proj = fromUnsafeProjection.get() + if (proj eq null) { + proj = new FromUnsafeProjection(rowTypes) + fromUnsafeProjection.set(proj) + } + var i = 0 + val ret = new CompactBuffer[InternalRow] + while (i < values.length) { + ret += proj(values(i)).copy() + i += 1 + } + ret + } else { + null + } } override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(keyTypes)) + writeBytes(out, SparkSqlSerializer.serialize(rowTypes)) val bytes = SparkSqlSerializer.serialize(hashTable) writeBytes(out, bytes) } override def readExternal(in: ObjectInput): Unit = { keyTypes = SparkSqlSerializer.deserialize(readBytes(in)) + rowTypes = SparkSqlSerializer.deserialize(readBytes(in)) hashTable = SparkSqlSerializer.deserialize(readBytes(in)) } } @@ -229,7 +251,8 @@ private[joins] object UnsafeHashedRelation { } } - val keySchema = buildKey.map(_.dataType).toArray - new UnsafeHashedRelation(hashTable, keySchema) + val keyTypes = buildKey.map(_.dataType).toArray + val rowTypes = rowSchema.fields.map(_.dataType).toArray + new UnsafeHashedRelation(hashTable, keyTypes, rowTypes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 8f5b77bdfe217..fc7e1e654e7ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -46,14 +46,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { val codegenEnabled = left.codegenEnabled buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = if (codegenEnabled && - buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { - UnsafeHashedRelation(buildIter, - 
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), - buildPlan.schema) - } else { - HashedRelation(buildIter, buildSideKeyGenerator) - } + val hashed = buildHashRelation(buildIter) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ab0a6ad56acde..f54f1edd38ec8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,24 +50,25 @@ case class ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) + val hashed = buildHashRelation(rightIter) + val keyGenerator = streamedKeyGenerator() leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) }) case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) + val hashed = buildHashRelation(leftIter) + val keyGenerator = streamedKeyGenerator() rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) }) case FullOuter => + // TODO(davies): use UnsafeRow val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index a936a9b81e16e..dcf0bc1e065bb 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -72,7 +72,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { */ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; - for (int i = 0; i < bitSetWidthInWords; i += 8 * WORD_SIZE, addr += WORD_SIZE) { + for (int i = 0; i < bitSetWidthInWords; i += 1, addr += WORD_SIZE) { if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) { return true; } From 184b8524962a114637e13e21feba54a1674fdcb4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 20 Jul 2015 12:49:54 -0700 Subject: [PATCH 04/14] fix style --- .../spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index b9326166338c7..473e0d0120b46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -62,7 +62,6 @@ case class BroadcastHashOuterJoin( // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() val hashed = buildHashRelation(input.iterator) - //val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output)) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) From 60371f28e1660078d43cd58d3ec2b20c29b2ae23 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 20 Jul 2015 14:59:34 -0700 Subject: [PATCH 05/14] use UnsafeRow in SemiJoin --- .../execution/UnsafeExternalRowSorter.java | 10 +-- .../sql/catalyst/expressions/Projection.scala | 3 + .../joins/BroadcastLeftSemiJoinHash.scala | 7 +-- .../spark/sql/execution/joins/HashJoin.scala | 6 +- .../sql/execution/joins/HashOuterJoin.scala | 9 +-- .../sql/execution/joins/HashSemiJoin.scala | 63 +++++++++++-------- .../execution/joins/LeftSemiJoinHash.scala | 4 +- .../execution/joins/HashedRelationSuite.scala | 12 ++-- 8 files changed, 64 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index d1d81c87bb052..39fd6e1bc6d13 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -28,11 +28,10 @@ import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ObjectPool; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; @@ -176,12 +175,7 @@ public Iterator sort(Iterator inputIterator) throws IO */ public static boolean supportsSchema(StructType schema) { // TODO: add spilling note to explain why we do this for now: - for (StructField field : schema.fields()) { - if (!UnsafeColumnWriter.canEmbed(field.dataType())) { - return false; - } - } - return true; + return UnsafeProjection.canSupport(schema); } private static final class RowComparator extends RecordComparator { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index b2e0d3b0b187a..f8ca86cda4bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -83,6 +83,9 @@ abstract class UnsafeProjection extends Projection { } object UnsafeProjection { + def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) + def canSupport(types: Seq[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) + def create(schema: StructType): UnsafeProjection = 
create(schema.fields.map(_.dataType)) def create(fields: Array[DataType]): UnsafeProjection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 2750f58b005ac..91477ed91f4ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -40,16 +40,15 @@ case class BroadcastLeftSemiJoinHash( val buildIter = right.execute().map(_.copy()).collect().toIterator if (condition.isEmpty) { - // rowKey may be not serializable (from codegen) - val hashSet = buildKeyHashSet(buildIter, copy = true) + val hashSet = buildKeyHashSet(buildIter) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) - val broadcastedRelation = sparkContext.broadcast(hashRelation) + val hashed = buildHashRelation(buildIter) + val broadcastedRelation = sparkContext.broadcast(hashed) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 79ff5b6a2b629..b5dc24bca85b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -105,8 +105,10 @@ trait HashJoin { } protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (self.codegenEnabled && buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { - UnsafeHashedRelation(buildIter, + if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys.map(_.dataType)) + && UnsafeProjection.canSupport(buildPlan.output.map(_.dataType))) { + UnsafeHashedRelation( + buildIter, buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), buildPlan.schema) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index dfbe5b9622cd4..8bd13781ebaa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -76,8 +76,7 @@ trait HashOuterJoin { } protected[this] def streamedKeyGenerator(): Projection = { - if (self.codegenEnabled && - streamedKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { + if (self.codegenEnabled && UnsafeProjection.canSupport(streamedKeys.map(_.dataType))) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newProjection(streamedKeys, streamedPlan.output) @@ -213,8 +212,10 @@ trait HashOuterJoin { } protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (self.codegenEnabled && buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) { - UnsafeHashedRelation(buildIter, + if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys.map(_.dataType)) + && UnsafeProjection.canSupport(buildPlan.output.map(_.dataType))) { + UnsafeHashedRelation( + buildIter, 
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), buildPlan.schema) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 1b983bc3a90f9..338515fad9bdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -32,34 +32,33 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - @transient protected lazy val rightKeyGenerator: Projection = - newProjection(rightKeys, right.output) - - @transient protected lazy val leftKeyGenerator: () => MutableProjection = - newMutableProjection(leftKeys, left.output) + @transient protected lazy val leftKeyGenerator: Projection = + if (canUseUnsafeRow) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newMutableProjection(leftKeys, left.output)() + } @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet( - buildIter: Iterator[InternalRow], - copy: Boolean): java.util.Set[InternalRow] = { + protected def buildKeyHashSet( buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() var currentRow: InternalRow = null // Create a Hash set of buildKeys + val rightKeyGenerator = if (canUseUnsafeRow) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = rightKeyGenerator(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { - if (copy) { - hashSet.add(rowKey.copy()) - } else { - // rowKey may be not serializable (from codegen) - hashSet.add(rowKey) - } + hashSet.add(rowKey.copy()) } } } @@ -67,25 +66,39 @@ trait HashSemiJoin { } protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator() - val joinedRow = new JoinedRow + streamIter: Iterator[InternalRow], + hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator streamIter.filter(current => { - lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue) - !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists { - (build: InternalRow) => boundCondition(joinedRow(current, build)) - } + val key = joinKeys(current) + !key.anyNull && hashSet.contains(key) }) } + private lazy val canUseUnsafeRow: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(rightKeys.map(_.dataType)) + && UnsafeProjection.canSupport(right.output.map(_.dataType))) + } + + protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (canUseUnsafeRow) { + UnsafeHashedRelation(buildIter, rightKeys, right.schema) + } else { + HashedRelation(buildIter, newProjection(rightKeys, right.output)) + } + } + protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator() + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow streamIter.filter(current => { - !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue) + val key = joinKeys(current) + lazy 
val rowBuffer = hashedRelation.get(key) + !key.anyNull && rowBuffer != null && rowBuffer.exists { + (row: InternalRow) => boundCondition(joinedRow(current, row)) + } }) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 9eaac817d9268..874712a4e739f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -43,10 +43,10 @@ case class LeftSemiJoinHash( protected override def doExecute(): RDD[InternalRow] = { right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter, copy = false) + val hashSet = buildKeyHashSet(buildIter) hashSemiJoin(streamIter, hashSet) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val hashRelation = buildHashRelation(buildIter) hashSemiJoin(streamIter, hashRelation) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 28194df7be906..417e5210da25b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -70,9 +70,11 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - val toUnsafe = UnsafeProjection.create(schema) - assert(hashed.get(data(0)) === CompactBuffer[UnsafeRow](toUnsafe(data(0)))) - assert(hashed.get(data(1)) === CompactBuffer[UnsafeRow](toUnsafe(data(1)))) + // TODO: enable this once we don't return generic row from UnsafeHashRelation.get() + // val toUnsafe = UnsafeProjection.create(schema) + val toUnsafe = (x: InternalRow) => x + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](toUnsafe(data(0)))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](toUnsafe(data(1)))) assert(hashed.get(InternalRow(10)) === null) val data2 = CompactBuffer[InternalRow](toUnsafe(data(2)).copy()) @@ -81,8 +83,8 @@ class HashedRelationSuite extends SparkFunSuite { val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) .asInstanceOf[UnsafeHashedRelation] - assert(hashed2.get(data(0)) === CompactBuffer[UnsafeRow](toUnsafe(data(0)))) - assert(hashed2.get(data(1)) === CompactBuffer[UnsafeRow](toUnsafe(data(1)))) + assert(hashed2.get(data(0)) === CompactBuffer[InternalRow](toUnsafe(data(0)))) + assert(hashed2.get(data(1)) === CompactBuffer[InternalRow](toUnsafe(data(1)))) assert(hashed2.get(InternalRow(10)) === null) assert(hashed2.get(data(2)) === data2) } From ab1690f4f687c1990cf6a268ac418a867f0eb928 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 20 Jul 2015 15:16:39 -0700 Subject: [PATCH 06/14] address comments --- .../sql/catalyst/expressions/UnsafeRow.java | 9 ++++---- .../sql/catalyst/expressions/Projection.scala | 21 ++++++++++++++++++- .../joins/BroadcastLeftSemiJoinHash.scala | 4 ++-- .../spark/sql/execution/joins/HashJoin.scala | 4 ++-- .../sql/execution/joins/HashOuterJoin.scala | 14 ++++++------- .../sql/execution/joins/HashSemiJoin.scala | 4 ++-- .../sql/execution/joins/HashedRelation.scala | 15 +++++++------ 
.../execution/joins/ShuffledHashJoin.scala | 1 - 8 files changed, 45 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 4e0cb35ac12af..9a942102a920f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -375,14 +375,15 @@ public int hashCode() { public boolean equals(Object other) { if (other instanceof UnsafeRow) { UnsafeRow o = (UnsafeRow) other; - return ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, - sizeInBytes); + return (sizeInBytes == o.sizeInBytes) && + ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, + sizeInBytes); } return false; } /** - * Returns the underline bytes for this UnsafeRow. + * Returns the underlying bytes for this UnsafeRow. */ public byte[] getBytes() { if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET @@ -398,7 +399,7 @@ public byte[] getBytes() { // This is for debugging @Override - public String toString(){ + public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index f8ca86cda4bca..69758e653eba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -83,20 +83,39 @@ abstract class UnsafeProjection extends Projection { } object UnsafeProjection { + + /* + * Returns whether UnsafeProjection can support given StructType, Array[DataType] or + * Seq[Expression]. + */ def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) - def canSupport(types: Seq[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) + def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) + def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) + /** + * Returns an UnsafeProjection for given StructType. + */ def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) + /** + * Returns an UnsafeProjection for given Array of DataTypes. + */ def create(fields: Array[DataType]): UnsafeProjection = { val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) create(exprs) } + /** + * Returns an UnsafeProjection for given sequence of Expressions (bounded). + */ def create(exprs: Seq[Expression]): UnsafeProjection = { GenerateUnsafeProjection.generate(exprs) } + /** + * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to + * `inputSchema`. 
+ */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 91477ed91f4ba..f71c0ce352904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -47,8 +47,8 @@ case class BroadcastLeftSemiJoinHash( hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashed = buildHashRelation(buildIter) - val broadcastedRelation = sparkContext.broadcast(hashed) + val hashRelation = buildHashRelation(buildIter) + val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index b5dc24bca85b0..3e13c20c9b662 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -105,8 +105,8 @@ trait HashJoin { } protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys.map(_.dataType)) - && UnsafeProjection.canSupport(buildPlan.output.map(_.dataType))) { + if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(buildPlan.schema)) { UnsafeHashedRelation( buildIter, buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 8bd13781ebaa8..fee75e53cdb22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -64,7 +64,7 @@ trait HashOuterJoin { case LeftOuter => (right, left) case x => throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") + s"HashOuterJoin should not take $x as the JoinType") } protected[this] lazy val (buildKeys, streamedKeys) = joinType match { @@ -72,11 +72,11 @@ trait HashOuterJoin { case LeftOuter => (rightKeys, leftKeys) case x => throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") + s"HashOuterJoin should not take $x as the JoinType") } protected[this] def streamedKeyGenerator(): Projection = { - if (self.codegenEnabled && UnsafeProjection.canSupport(streamedKeys.map(_.dataType))) { + if (self.codegenEnabled && UnsafeProjection.canSupport(streamedKeys)) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newProjection(streamedKeys, streamedPlan.output) @@ -105,7 +105,7 @@ trait HashOuterJoin { case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } } else { - List() + List.empty } if (temp.isEmpty) { joinedRow.withRight(rightNullRow).copy :: Nil @@ -131,7 +131,7 @@ trait HashOuterJoin { joinedRow.copy() } } else { - List() + List.empty } if (temp.isEmpty) { 
joinedRow.withLeft(leftNullRow).copy :: Nil @@ -212,8 +212,8 @@ trait HashOuterJoin { } protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys.map(_.dataType)) - && UnsafeProjection.canSupport(buildPlan.output.map(_.dataType))) { + if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(buildPlan.schema)) { UnsafeHashedRelation( buildIter, buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 338515fad9bdf..92a9a504cadb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -76,8 +76,8 @@ trait HashSemiJoin { } private lazy val canUseUnsafeRow: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(rightKeys.map(_.dataType)) - && UnsafeProjection.canSupport(right.output.map(_.dataType))) + (self.codegenEnabled && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(right.schema)) } protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b0bf7ad6d745f..220e9dda504e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -204,8 +204,7 @@ private[joins] final class UnsafeHashedRelation( override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(keyTypes)) writeBytes(out, SparkSqlSerializer.serialize(rowTypes)) - val bytes = SparkSqlSerializer.serialize(hashTable) - writeBytes(out, bytes) + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) } override def readExternal(in: ObjectInput): Unit = { @@ -218,15 +217,15 @@ private[joins] final class UnsafeHashedRelation( private[joins] object UnsafeHashedRelation { def apply( - input: Iterator[InternalRow], - buildKey: Seq[Expression], - rowSchema: StructType, - sizeEstimate: Int = 64): HashedRelation = { + input: Iterator[InternalRow], + buildKey: Seq[Expression], + rowSchema: StructType, + sizeEstimate: Int = 64): HashedRelation = { // TODO: Use BytesToBytesMap. 
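
(A rough usage sketch, not part of the patch: the schema, field names, and values below are assumed for illustration.) The UnsafeProjection factory methods documented in the Projection.scala hunk above are meant to be used like this to turn an InternalRow into its UnsafeRow encoding:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    // Hypothetical schema; both field types can be embedded by UnsafeColumnWriter.
    val schema = StructType(
      StructField("id", IntegerType, nullable = false) ::
      StructField("name", StringType, nullable = true) :: Nil)

    if (UnsafeProjection.canSupport(schema)) {
      val toUnsafe = UnsafeProjection.create(schema)                  // StructType overload
      val row = toUnsafe(InternalRow(1, UTF8String.fromString("a")))  // produces an UnsafeRow
      val kept = row.copy()  // the projection reuses its output buffer, so copy() before buffering
    }
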
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) var currentRow: InternalRow = null - val rowProj = UnsafeProjection.create(rowSchema) + val toUnsafe = UnsafeProjection.create(rowSchema) val keyGenerator = UnsafeProjection.create(buildKey) // Create a mapping of buildKeys -> rows @@ -235,7 +234,7 @@ private[joins] object UnsafeHashedRelation { val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { currentRow.asInstanceOf[UnsafeRow] } else { - rowProj(currentRow) + toUnsafe(currentRow) } val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index fc7e1e654e7ec..948d0ccebceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -44,7 +44,6 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { - val codegenEnabled = left.codegenEnabled buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashed = buildHashRelation(buildIter) hashJoin(streamIter, hashed) From 1a40f02df481263d7dc25aa5b96157e2f6a5380f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 20 Jul 2015 16:15:36 -0700 Subject: [PATCH 07/14] refactor --- .../joins/BroadcastHashOuterJoin.scala | 2 +- .../spark/sql/execution/joins/HashJoin.scala | 28 ++++++++----- .../sql/execution/joins/HashOuterJoin.scala | 16 ++++---- .../sql/execution/joins/HashSemiJoin.scala | 2 +- .../sql/execution/joins/HashedRelation.scala | 40 +++++++++---------- .../execution/joins/HashedRelationSuite.scala | 27 ++++++------- 6 files changed, 59 insertions(+), 56 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 473e0d0120b46..c9d1a880f4ef4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -71,7 +71,7 @@ case class BroadcastHashOuterJoin( streamedPlan.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow() val hashTable = broadcastRelation.value - val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + val keyGenerator = streamedKeyGenerator joinType match { case LeftOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 3e13c20c9b662..51296ee6d7ca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -47,8 +47,12 @@ trait HashJoin { @transient protected lazy val buildSideKeyGenerator: Projection = newProjection(buildKeys, buildPlan.output) - @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = - newMutableProjection(streamedKeys, streamedPlan.output) + @transient protected lazy val streamSideKeyGenerator: Projection = + if (canUseUnsafeRow) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + 
newMutableProjection(streamedKeys, streamedPlan.output)() + } protected def hashJoin( streamIter: Iterator[InternalRow], @@ -62,7 +66,7 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow2 - private[this] val joinKeys = streamSideKeyGenerator() + private[this] val joinKeys = streamSideKeyGenerator override final def hasNext: Boolean = (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || @@ -89,8 +93,9 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashedRelation.get(joinKeys.currentValue) + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashedRelation.get(key) } } @@ -104,13 +109,14 @@ trait HashJoin { } } + protected[this] def canUseUnsafeRow: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(buildPlan.schema)) + } + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(buildPlan.schema)) { - UnsafeHashedRelation( - buildIter, - buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), - buildPlan.schema) + if (canUseUnsafeRow) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) } else { HashedRelation(buildIter, buildSideKeyGenerator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index fee75e53cdb22..e322fd8cbced1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -76,7 +76,7 @@ trait HashOuterJoin { } protected[this] def streamedKeyGenerator(): Projection = { - if (self.codegenEnabled && UnsafeProjection.canSupport(streamedKeys)) { + if (canUseUnsafeRow) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newProjection(streamedKeys, streamedPlan.output) @@ -191,6 +191,7 @@ trait HashOuterJoin { } } + // This is only used by FullOuter protected[this] def buildHashTable( iter: Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { @@ -211,13 +212,14 @@ trait HashOuterJoin { hashTable } + protected[this] def canUseUnsafeRow: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(buildPlan.schema)) + } + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(buildPlan.schema)) { - UnsafeHashedRelation( - buildIter, - buildKeys.map(BindReferences.bindReference(_, buildPlan.output)), - buildPlan.schema) + if (canUseUnsafeRow) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) } else { HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 92a9a504cadb4..3e4119e62c263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -82,7 +82,7 @@ trait HashSemiJoin { protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { if (canUseUnsafeRow) { - UnsafeHashedRelation(buildIter, rightKeys, right.schema) + UnsafeHashedRelation(buildIter, rightKeys, right) } else { HashedRelation(buildIter, newProjection(rightKeys, right.output)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 220e9dda504e3..12c800668148b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -22,8 +22,8 @@ import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} +import org.apache.spark.sql.types.{StructType, DataType} import org.apache.spark.util.collection.CompactBuffer @@ -158,27 +158,16 @@ private[joins] object HashedRelation { */ private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]], - private var keyTypes: Array[DataType], private var rowTypes: Array[DataType]) extends HashedRelation with Externalizable { - def this() = this(null, null, null) // Needed for serialization + def this() = this(null, null) // Needed for serialization // UnsafeProjection is not thread safe - @transient lazy val keyProjection = new ThreadLocal[UnsafeProjection] @transient lazy val fromUnsafeProjection = new ThreadLocal[FromUnsafeProjection] override def get(key: InternalRow): CompactBuffer[InternalRow] = { - val unsafeKey = if (key.isInstanceOf[UnsafeRow]) { - key.asInstanceOf[UnsafeRow] - } else { - var proj = keyProjection.get() - if (proj eq null) { - proj = UnsafeProjection.create(keyTypes) - keyProjection.set(proj) - } - proj(key) - } + val unsafeKey = key.asInstanceOf[UnsafeRow] val values = hashTable.get(unsafeKey) // Return GenericInternalRow to work with other JoinRow, which @@ -202,13 +191,11 @@ private[joins] final class UnsafeHashedRelation( } override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(keyTypes)) writeBytes(out, SparkSqlSerializer.serialize(rowTypes)) writeBytes(out, SparkSqlSerializer.serialize(hashTable)) } override def readExternal(in: ObjectInput): Unit = { - keyTypes = SparkSqlSerializer.deserialize(readBytes(in)) rowTypes = SparkSqlSerializer.deserialize(readBytes(in)) hashTable = SparkSqlSerializer.deserialize(readBytes(in)) } @@ -218,15 +205,25 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - buildKey: Seq[Expression], - rowSchema: StructType, + buildKeys: Seq[Expression], + buildPlan: SparkPlan, sizeEstimate: Int = 64): HashedRelation = { + val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) + apply(input, boundedKeys, buildPlan.schema, sizeEstimate) + } + + // Used for tests + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + rowSchema: StructType, + sizeEstimate: Int): HashedRelation = { // TODO: Use BytesToBytesMap. 
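
(Sketch only; the values are made up and mirror the HashedRelationSuite changes later in this patch. UnsafeHashedRelation is package-private, so code like this has to live under org.apache.spark.sql.execution.joins.) The test-oriented apply overload above is exercised by building the relation from InternalRows and probing it with UnsafeRow keys:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    val rows = Array(InternalRow(0), InternalRow(1), InternalRow(1))
    val buildKeys = Seq(BoundReference(0, IntegerType, false))
    val schema = StructType(StructField("a", IntegerType, nullable = true) :: Nil)

    // Rows are converted to UnsafeRow (if needed) and grouped by their build key.
    val hashed = UnsafeHashedRelation(rows.iterator, buildKeys, schema, 64)

    // Probe with an UnsafeRow key: returns the buffered matches, or null for a miss.
    val toUnsafeKey = UnsafeProjection.create(schema)
    val matches = hashed.get(toUnsafeKey(InternalRow(1)))   // two matching rows
    val miss = hashed.get(toUnsafeKey(InternalRow(42)))     // null
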
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) var currentRow: InternalRow = null val toUnsafe = UnsafeProjection.create(rowSchema) - val keyGenerator = UnsafeProjection.create(buildKey) + val keyGenerator = UnsafeProjection.create(buildKeys) // Create a mapping of buildKeys -> rows while (input.hasNext) { @@ -250,8 +247,7 @@ private[joins] object UnsafeHashedRelation { } } - val keyTypes = buildKey.map(_.dataType).toArray val rowTypes = rowSchema.fields.map(_.dataType).toArray - new UnsafeHashedRelation(hashTable, keyTypes, rowTypes) + new UnsafeHashedRelation(hashTable, rowTypes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 417e5210da25b..6b9800e763d01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -67,25 +67,24 @@ class HashedRelationSuite extends SparkFunSuite { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val buildKey = Seq(BoundReference(0, IntegerType, false)) val schema = StructType(StructField("a", IntegerType, true) :: Nil) - val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema) + val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - // TODO: enable this once we don't return generic row from UnsafeHashRelation.get() - // val toUnsafe = UnsafeProjection.create(schema) - val toUnsafe = (x: InternalRow) => x - assert(hashed.get(data(0)) === CompactBuffer[InternalRow](toUnsafe(data(0)))) - assert(hashed.get(data(1)) === CompactBuffer[InternalRow](toUnsafe(data(1)))) - assert(hashed.get(InternalRow(10)) === null) + val toUnsafeKey = UnsafeProjection.create(schema) + val keys = data.map(toUnsafeKey(_).copy()).toArray + assert(hashed.get(keys(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(keys(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) - val data2 = CompactBuffer[InternalRow](toUnsafe(data(2)).copy()) - data2 += toUnsafe(data(2)).copy() - assert(hashed.get(data(2)) === data2) + val data2 = CompactBuffer[InternalRow](data(2).copy()) + data2 += data(2).copy() + assert(hashed.get(keys(2)) === data2) val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) .asInstanceOf[UnsafeHashedRelation] - assert(hashed2.get(data(0)) === CompactBuffer[InternalRow](toUnsafe(data(0)))) - assert(hashed2.get(data(1)) === CompactBuffer[InternalRow](toUnsafe(data(1)))) - assert(hashed2.get(InternalRow(10)) === null) - assert(hashed2.get(data(2)) === data2) + assert(hashed2.get(keys(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed2.get(keys(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed2.get(keys(2)) === data2) } } From 0f4380d79db4d09628cf3365b9bf2c39a5b8b107 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Jul 2015 09:51:57 -0700 Subject: [PATCH 08/14] ada a comment --- .../org/apache/spark/sql/execution/joins/HashedRelation.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 
12c800668148b..173e93c0b7067 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -154,6 +154,8 @@ private[joins] object HashedRelation { * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a * sequence of values. * + * It's safe to to use inside a broadcast, and safe to read by multiple threads. + * * TODO(davies): use BytesToBytesMap */ private[joins] final class UnsafeHashedRelation( From ca2b40fba72bd698a8638a8d9aaf08fc2ddf8a2e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Jul 2015 10:38:21 -0700 Subject: [PATCH 09/14] revert unrelated change --- .../org/apache/spark/sql/execution/joins/HashSemiJoin.scala | 2 +- .../main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 3e4119e62c263..c67dad98ddcca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -42,7 +42,7 @@ trait HashSemiJoin { @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet( buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { + protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() var currentRow: InternalRow = null diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index dcf0bc1e065bb..27462c7fa5e62 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -72,7 +72,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { */ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; - for (int i = 0; i < bitSetWidthInWords; i += 1, addr += WORD_SIZE) { + for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) { return true; } From 9481ae89a98caaf63d2b7eb27ee25a0c58542ce0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Jul 2015 12:46:13 -0700 Subject: [PATCH 10/14] return UnsafeRow after join() --- .../spark/sql/execution/joins/HashJoin.scala | 31 +++++++++----- .../sql/execution/joins/HashOuterJoin.scala | 40 ++++++++++++------- .../sql/execution/joins/HashSemiJoin.scala | 33 +++++++++------ .../sql/execution/joins/HashedRelation.scala | 39 +++--------------- .../sql/execution/rowFormatConverters.scala | 3 ++ .../org/apache/spark/sql/UnsafeRowSuite.scala | 4 +- .../execution/joins/HashedRelationSuite.scala | 18 ++++----- 7 files changed, 86 insertions(+), 82 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 51296ee6d7ca0..c2da44cf96d4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,11 +44,16 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator: Projection = - newProjection(buildKeys, buildPlan.output) + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe @transient protected lazy val streamSideKeyGenerator: Projection = - if (canUseUnsafeRow) { + if (supportUnsafe) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newMutableProjection(streamedKeys, streamedPlan.output)() @@ -65,6 +70,11 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow2 + private[this] val resultProjection: Projection = if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + ((r: InternalRow) => r).asInstanceOf[Projection] + } private[this] val joinKeys = streamSideKeyGenerator @@ -78,7 +88,11 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 - ret + if (supportUnsafe) { + resultProjection(ret) + } else { + ret + } } /** @@ -109,16 +123,11 @@ trait HashJoin { } } - protected[this] def canUseUnsafeRow: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(buildPlan.schema)) - } - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (canUseUnsafeRow) { + if (supportUnsafe) { UnsafeHashedRelation(buildIter, buildKeys, buildPlan) } else { - HashedRelation(buildIter, buildSideKeyGenerator) + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index e322fd8cbced1..e1c05b5c56b37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -75,14 +75,31 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && joinType != FullOuter + && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe + protected[this] def streamedKeyGenerator(): Projection = { - if (canUseUnsafeRow) { + if (supportUnsafe) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newProjection(streamedKeys, streamedPlan.output) } } + @transient private[this] lazy val resultProjection: Projection = + if (supportUnsafe) { + 
// Converted returned JoinRow into UnsafeRow + UnsafeProjection.create(self.schema) + } else { + ((x: InternalRow) => x).asInstanceOf[Projection] + } + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -102,18 +119,18 @@ trait HashOuterJoin { if (!key.anyNull) { val temp = if (rightIter != null) { rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() + case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() } } else { List.empty } if (temp.isEmpty) { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } } ret.iterator @@ -128,18 +145,18 @@ trait HashOuterJoin { val temp = if (leftIter != null) { leftIter.collect { case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy() + resultProjection(joinedRow).copy() } } else { List.empty } if (temp.isEmpty) { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } } ret.iterator @@ -212,13 +229,8 @@ trait HashOuterJoin { hashTable } - protected[this] def canUseUnsafeRow: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(buildPlan.schema)) - } - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (canUseUnsafeRow) { + if (supportUnsafe) { UnsafeHashedRelation(buildIter, buildKeys, buildPlan) } else { HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index c67dad98ddcca..23ed796772453 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -32,13 +32,29 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(left.schema)) + } + + override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = supportUnsafe + @transient protected lazy val leftKeyGenerator: Projection = - if (canUseUnsafeRow) { + if (supportUnsafe) { UnsafeProjection.create(leftKeys, left.output) } else { newMutableProjection(leftKeys, left.output)() } + @transient protected lazy val rightKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newMutableProjection(rightKeys, right.output)() + } + @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) @@ -47,14 +63,10 @@ trait HashSemiJoin { var currentRow: InternalRow = null // Create a Hash set of buildKeys - val rightKeyGenerator = if (canUseUnsafeRow) { - UnsafeProjection.create(rightKeys, right.output) - } else { - newProjection(rightKeys, right.output) - } + val 
rightKey = rightKeyGenerator while (buildIter.hasNext) { currentRow = buildIter.next() - val rowKey = rightKeyGenerator(currentRow) + val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { @@ -75,13 +87,8 @@ trait HashSemiJoin { }) } - private lazy val canUseUnsafeRow: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(right.schema)) - } - protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (canUseUnsafeRow) { + if (supportUnsafe) { UnsafeHashedRelation(buildIter, rightKeys, right) } else { HashedRelation(buildIter, newProjection(rightKeys, right.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 173e93c0b7067..1b8c90e9a89c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -23,7 +23,7 @@ import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.CompactBuffer @@ -154,51 +154,25 @@ private[joins] object HashedRelation { * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a * sequence of values. * - * It's safe to to use inside a broadcast, and safe to read by multiple threads. - * * TODO(davies): use BytesToBytesMap */ private[joins] final class UnsafeHashedRelation( - private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]], - private var rowTypes: Array[DataType]) + private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation with Externalizable { - def this() = this(null, null) // Needed for serialization - - // UnsafeProjection is not thread safe - @transient lazy val fromUnsafeProjection = new ThreadLocal[FromUnsafeProjection] + def this() = this(null) // Needed for serialization override def get(key: InternalRow): CompactBuffer[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] - - val values = hashTable.get(unsafeKey) - // Return GenericInternalRow to work with other JoinRow, which - // TODO(davies): return UnsafeRow once we have UnsafeJoinRow. 
- if (values != null) { - var proj = fromUnsafeProjection.get() - if (proj eq null) { - proj = new FromUnsafeProjection(rowTypes) - fromUnsafeProjection.set(proj) - } - var i = 0 - val ret = new CompactBuffer[InternalRow] - while (i < values.length) { - ret += proj(values(i)).copy() - i += 1 - } - ret - } else { - null - } + // Thanks to type eraser + hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] } override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(rowTypes)) writeBytes(out, SparkSqlSerializer.serialize(hashTable)) } override def readExternal(in: ObjectInput): Unit = { - rowTypes = SparkSqlSerializer.deserialize(readBytes(in)) hashTable = SparkSqlSerializer.deserialize(readBytes(in)) } } @@ -249,7 +223,6 @@ private[joins] object UnsafeHashedRelation { } } - val rowTypes = rowSchema.fields.map(_.dataType).toArray - new UnsafeHashedRelation(hashTable, rowTypes) + new UnsafeHashedRelation(hashTable) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 421d510e6782d..14a9f0523ae0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule */ @DeveloperApi case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { + + require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") + override def output: Seq[Attribute] = child.output override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 3854dc1b7a3d1..d36e2639376e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -31,7 +31,7 @@ class UnsafeRowSuite extends SparkFunSuite { test("writeToStream") { val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) val arrayBackedUnsafeRow: UnsafeRow = - UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row) + UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) val bytesFromArrayBackedRow: Array[Byte] = { val baos = new ByteArrayOutputStream() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 6b9800e763d01..9dd2220f0967e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -71,20 +71,20 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed.isInstanceOf[UnsafeHashedRelation]) val toUnsafeKey = UnsafeProjection.create(schema) - val keys = data.map(toUnsafeKey(_).copy()).toArray - assert(hashed.get(keys(0)) === CompactBuffer[InternalRow](data(0))) - assert(hashed.get(keys(1)) === CompactBuffer[InternalRow](data(1))) + val unsafeData = data.map(toUnsafeKey(_).copy()).toArray + assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) - val data2 = CompactBuffer[InternalRow](data(2).copy()) - data2 += data(2).copy() - assert(hashed.get(keys(2)) === data2) + val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) + data2 += unsafeData(2).copy() + assert(hashed.get(unsafeData(2)) === data2) val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) .asInstanceOf[UnsafeHashedRelation] - assert(hashed2.get(keys(0)) === CompactBuffer[InternalRow](data(0))) - assert(hashed2.get(keys(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null) - assert(hashed2.get(keys(2)) === data2) + assert(hashed2.get(unsafeData(2)) === data2) } } From a05b4f6e0eb6cf9d983120f470571cb8fdefa3ee Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Jul 2015 13:55:23 -0700 Subject: [PATCH 11/14] support UnsafeRow in LeftSemiJoinBNL and BroadcastNestedLoopJoin --- .../expressions/UnsafeRowConverter.scala | 4 +-- .../joins/BroadcastNestedLoopJoin.scala | 35 ++++++++++++------- .../sql/execution/joins/LeftSemiJoinBNL.scala | 3 ++ .../sql/execution/rowFormatConverters.scala | 18 +++++++--- 4 files changed, 41 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 702deb04acb67..885ab091fcdf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -111,7 +111,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { /** * Function for writing a column into an UnsafeRow. */ -abstract class UnsafeColumnWriter { +private abstract class UnsafeColumnWriter { /** * Write a value into an UnsafeRow. 
* @@ -130,7 +130,7 @@ abstract class UnsafeColumnWriter { def getSize(source: InternalRow, column: Int): Int } -object UnsafeColumnWriter { +private object UnsafeColumnWriter { def forType(dataType: DataType): UnsafeColumnWriter = { dataType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 60b4266fad8b1..7be83ddfc17a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -44,6 +44,17 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + + @transient private[this] lazy val resultProjection: Projection = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + ((r: InternalRow) => r).asInstanceOf[Projection] + } + } + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = { @@ -74,6 +85,7 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -86,11 +98,11 @@ case class BroadcastNestedLoopJoin( val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case _ => @@ -100,9 +112,9 @@ case class BroadcastNestedLoopJoin( (streamRowMatched, joinType, buildSide) match { case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() + matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy() case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() + matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy() case _ => } } @@ -110,12 +122,9 @@ case class BroadcastNestedLoopJoin( } val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } + val allIncludedBroadcastTuples = includedBroadcastTuples.fold( + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + )(_ ++ _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -127,8 +136,10 @@ case class BroadcastNestedLoopJoin( while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - 
case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => + buf += resultProjection(new JoinedRow(leftNulls, rel(i))) + case (LeftOuter | FullOuter, BuildLeft) => + buf += resultProjection(new JoinedRow(rel(i), rightNulls)) case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index db5be9f453674..4443455ef11fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -39,6 +39,9 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output + override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 14a9f0523ae0d..29f3beb3cb3c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -96,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { } case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, then convert everything - // to unsafe rows - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + // If this operator's children produce both unsafe and safe rows, + // convert everything unsafe rows if all the schema of them are support by UnsafeRow + if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator.withNewChildren { + operator.children.map { + c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c + } } } } else { From 84c98073b1a4a6ec6a461d6826dd04ec87a6eb96 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Jul 2015 14:47:20 -0700 Subject: [PATCH 12/14] address comments --- .../org/apache/spark/sql/execution/joins/HashJoin.scala | 6 +----- .../org/apache/spark/sql/execution/joins/HashSemiJoin.scala | 4 ++-- .../apache/spark/sql/execution/joins/HashedRelation.scala | 3 +-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index c2da44cf96d4d..9e7ef7aa76b52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -88,11 +88,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 - if (supportUnsafe) { - resultProjection(ret) - } else { - ret - } + resultProjection(ret) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala 
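
(The snippet below only restates, in one contiguous block, the EnsureRowFormats branch added in the rowFormatConverters.scala hunk above; operator is the SparkPlan being rewritten.) When an operator's children disagree on row format, the rule now prefers UnsafeRow only if every child's schema can be encoded by UnsafeRow, and otherwise converts everything back to safe rows:

    val newChildren =
      if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) {
        // All children can produce UnsafeRows: upgrade the safe ones.
        operator.children.map(c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c)
      } else {
        // At least one schema is unsupported: downgrade the unsafe ones instead.
        operator.children.map(c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c)
      }
    operator.withNewChildren(newChildren)
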
index 23ed796772453..7f49264d40354 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -100,12 +100,12 @@ trait HashSemiJoin { hashedRelation: HashedRelation): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow - streamIter.filter(current => { + streamIter.filter { current => val key = joinKeys(current) lazy val rowBuffer = hashedRelation.get(key) !key.anyNull && rowBuffer != null && rowBuffer.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) } - }) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 1b8c90e9a89c5..8d5731afd59b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -197,13 +197,12 @@ private[joins] object UnsafeHashedRelation { // TODO: Use BytesToBytesMap. val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) - var currentRow: InternalRow = null val toUnsafe = UnsafeProjection.create(rowSchema) val keyGenerator = UnsafeProjection.create(buildKeys) // Create a mapping of buildKeys -> rows while (input.hasNext) { - currentRow = input.next() + val currentRow = input.next() val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { currentRow.asInstanceOf[UnsafeRow] } else { From dede020a7dd8977ed0a825dfad9fc96fb7b0a53e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Jul 2015 17:12:33 -0700 Subject: [PATCH 13/14] fix test --- .../sql/execution/joins/BroadcastNestedLoopJoin.scala | 5 ++--- .../apache/spark/sql/execution/joins/HashJoin.scala | 11 ++++++----- .../spark/sql/execution/joins/HashOuterJoin.scala | 5 ++--- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 7be83ddfc17a4..384eb6b532e35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -47,13 +47,12 @@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: Projection = { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { - ((r: InternalRow) => r).asInstanceOf[Projection] + (r: InternalRow) => r } - } override def outputPartitioning: Partitioning = streamed.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 9e7ef7aa76b52..645eb6e452586 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -70,11 +70,12 @@ trait HashJoin { // Mutable per row objects. 
private[this] val joinRow = new JoinedRow2 - private[this] val resultProjection: Projection = if (supportUnsafe) { - UnsafeProjection.create(self.schema) - } else { - ((r: InternalRow) => r).asInstanceOf[Projection] - } + private[this] val resultProjection: InternalRow => InternalRow = + if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + (r: InternalRow) => r + } private[this] val joinKeys = streamSideKeyGenerator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index e1c05b5c56b37..34ff6bd82676e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -92,12 +92,11 @@ trait HashOuterJoin { } } - @transient private[this] lazy val resultProjection: Projection = + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = if (supportUnsafe) { - // Converted returned JoinRow into UnsafeRow UnsafeProjection.create(self.schema) } else { - ((x: InternalRow) => x).asInstanceOf[Projection] + (r: InternalRow) => r } @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) From 6294b1e3de357c94646c323eba2d4bde80971c45 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Jul 2015 21:35:12 -0700 Subject: [PATCH 14/14] fix projection --- .../sql/execution/joins/BroadcastNestedLoopJoin.scala | 7 +++++-- .../org/apache/spark/sql/execution/joins/HashJoin.scala | 7 +++++-- .../apache/spark/sql/execution/joins/HashOuterJoin.scala | 7 +++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 384eb6b532e35..700636966f8be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -47,12 +47,15 @@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: InternalRow => InternalRow = + @transient private[this] lazy val resultProjection: Projection = { if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { - (r: InternalRow) => r + new Projection { + override def apply(r: InternalRow): InternalRow = r + } } + } override def outputPartitioning: Partitioning = streamed.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 645eb6e452586..ae34409bcfcca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -70,12 +70,15 @@ trait HashJoin { // Mutable per row objects. 
private[this] val joinRow = new JoinedRow2 - private[this] val resultProjection: InternalRow => InternalRow = + private[this] val resultProjection: Projection = { if (supportUnsafe) { UnsafeProjection.create(self.schema) } else { - (r: InternalRow) => r + new Projection { + override def apply(r: InternalRow): InternalRow = r + } } + } private[this] val joinKeys = streamSideKeyGenerator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 34ff6bd82676e..6bf2f82954046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -92,12 +92,15 @@ trait HashOuterJoin { } } - @transient private[this] lazy val resultProjection: InternalRow => InternalRow = + @transient private[this] lazy val resultProjection: Projection = { if (supportUnsafe) { UnsafeProjection.create(self.schema) } else { - (r: InternalRow) => r + new Projection { + override def apply(r: InternalRow): InternalRow = r + } } + } @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
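
The series ends by making resultProjection a Projection again, with the identity case written as an anonymous subclass. A bare function literal such as (r: InternalRow) => r is only a Function1, not a Projection, so the asInstanceOf[Projection] cast used in earlier revisions of this patch would throw a ClassCastException as soon as the non-unsafe branch was taken; that is presumably what the last two commits straighten out. A minimal sketch of the distinction (illustrative only, not part of the patch):

    // Not a Projection: casting this function literal would fail at runtime.
    val asFunction: InternalRow => InternalRow = (r: InternalRow) => r

    // The pattern the final commit settles on: an explicit identity Projection.
    val asProjection: Projection = new Projection {
      override def apply(r: InternalRow): InternalRow = r
    }
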