From 84d15d50693fbea35c11963484ef8cd798e7bd55 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Wed, 25 Mar 2015 20:01:21 -0700
Subject: [PATCH 1/2] return a constant empty CompactBuffer instead of null
 from HashedRelation.get

---
 .../spark/util/collection/CompactBuffer.scala | 16 +++++++++++++++-
 .../sql/execution/joins/HashedRelation.scala  | 19 +++++++++++++++++--
 .../sql/hive/InsertIntoHiveTableSuite.scala   |  2 +-
 3 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala
index 4d43d8d5cc8d8..40da093016848 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala
@@ -51,7 +51,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable
     }
   }
 
-  private def update(position: Int, value: T): Unit = {
+  def update(position: Int, value: T): Unit = {
     if (position < 0 || position >= curSize) {
       throw new IndexOutOfBoundsException
     }
@@ -152,6 +152,20 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable
 }
 
 private[spark] object CompactBuffer {
+  def constantEmpty[T: ClassTag](): CompactBuffer[T] = new CompactBuffer[T] {
+    override def apply(position: Int): T =
+      sys.error("apply() is not valid for a constant empty buffer")
+
+    override def update(position: Int, value: T): Unit =
+      sys.error("update() is not valid for a constant empty buffer")
+
+    override def += (value: T): CompactBuffer[T] =
+      sys.error("+= is not valid for a constant empty buffer")
+
+    override def ++= (values: TraversableOnce[T]): CompactBuffer[T] =
+      sys.error("++= is not valid for a constant empty buffer")
+  }
+
   def apply[T: ClassTag](): CompactBuffer[T] = new CompactBuffer[T]
 
   def apply[T: ClassTag](value: T): CompactBuffer[T] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 2fa1cf5add3b5..ff8e79ebe140f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -28,6 +28,7 @@ import org.apache.spark.util.collection.CompactBuffer
  * object.
 */
 private[joins] sealed trait HashedRelation {
+  protected[this] val constantEmptyBuffer = CompactBuffer.constantEmpty[Row]()
   def get(key: Row): CompactBuffer[Row]
 }
 
@@ -38,7 +39,14 @@
 
 private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]])
   extends HashedRelation with Serializable {
 
-  override def get(key: Row): CompactBuffer[Row] = hashTable.get(key)
+  override def get(key: Row): CompactBuffer[Row] = {
+    val buffer = hashTable.get(key)
+    if (buffer == null) {
+      constantEmptyBuffer
+    } else {
+      buffer
+    }
+  }
 }
 
 
@@ -48,10 +56,17 @@ private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, Com
  */
 private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row])
   extends HashedRelation with Serializable {
+  // Reused across get() calls: callers must not retain the returned buffer.
+  private val singleElementBuffer = CompactBuffer[Row](null)
 
   override def get(key: Row): CompactBuffer[Row] = {
     val v = hashTable.get(key)
-    if (v eq null) null else CompactBuffer(v)
+    if (v eq null) {
+      constantEmptyBuffer
+    } else {
+      singleElementBuffer(0) = v
+      singleElementBuffer
+    }
   }
 
   def getValue(key: Row): Row = hashTable.get(key)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index aa6fb42de7f88..8011952e0d535 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -198,7 +198,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
     val testDatawithNull = TestHive.sparkContext.parallelize(
       (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF()
 
-    val tmpDir = Files.createTempDir()
+    val tmpDir = Utils.createTempDir()
     sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ")
     sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData")
 

From 645b9bee819501e7aec8d2ae1b29812a857d9fde Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Wed, 25 Mar 2015 21:17:57 -0700
Subject: [PATCH 2/2] update empty-buffer checks in HashJoin and
 HashedRelationSuite

---
 .../org/apache/spark/sql/execution/joins/HashJoin.scala | 6 +++---
 .../spark/sql/execution/joins/HashedRelationSuite.scala | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 851de1685509a..92717ec88cce0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -81,17 +81,17 @@ trait HashJoin {
        *         tuples.
        */
       private final def fetchNext(): Boolean = {
-        currentHashMatches = null
+        currentHashMatches = CompactBuffer.constantEmpty[Row]()
         currentMatchPosition = -1
 
-        while (currentHashMatches == null && streamIter.hasNext) {
+        while (currentHashMatches.size == 0 && streamIter.hasNext) {
           currentStreamedRow = streamIter.next()
           if (!joinKeys(currentStreamedRow).anyNull) {
             currentHashMatches = hashedRelation.get(joinKeys.currentValue)
           }
         }
 
-        if (currentHashMatches == null) {
+        if (currentHashMatches.size == 0) {
           false
         } else {
           currentMatchPosition = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 2aad01ded1acf..608cb51c20231 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -37,7 +37,7 @@ class HashedRelationSuite extends FunSuite {
 
     assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
     assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
-    assert(hashed.get(Row(10)) === null)
+    assert(hashed.get(Row(10)).size === 0)
 
     val data2 = CompactBuffer[Row](data(2))
     data2 += data(2)
@@ -52,12 +52,12 @@
     assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
     assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
     assert(hashed.get(data(2)) == CompactBuffer[Row](data(2)))
-    assert(hashed.get(Row(10)) === null)
+    assert(hashed.get(Row(10)).size === 0)
 
     val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
     assert(uniqHashed.getValue(data(0)) == data(0))
    assert(uniqHashed.getValue(data(1)) == data(1))
     assert(uniqHashed.getValue(data(2)) == data(2))
-    assert(uniqHashed.getValue(Row(10)) == null)
+    assert(uniqHashed.getValue(Row(10)) === null)
   }
 }
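---

The two patches together apply the null-object pattern: HashedRelation.get now hands back a shared, immutable "constant empty" CompactBuffer instead of null, so HashJoin.fetchNext can branch on size alone. What follows is a minimal, self-contained Scala sketch of that pattern under simplified assumptions; MiniBuffer, MiniRelation, and NullObjectDemo are hypothetical names invented for illustration, not Spark classes.

// Simplified stand-in for CompactBuffer: a growable, indexable buffer.
class MiniBuffer[T] extends Serializable {
  private var elements = new Array[AnyRef](8)
  private var curSize = 0

  def size: Int = curSize

  def apply(position: Int): T = {
    if (position < 0 || position >= curSize) {
      throw new IndexOutOfBoundsException(position.toString)
    }
    elements(position).asInstanceOf[T]
  }

  def += (value: T): MiniBuffer[T] = {
    if (curSize == elements.length) {
      elements = java.util.Arrays.copyOf(elements, elements.length * 2)
    }
    elements(curSize) = value.asInstanceOf[AnyRef]
    curSize += 1
    this
  }
}

object MiniBuffer {
  // Null object: an empty buffer that rejects mutation, so a single immutable
  // instance can stand in for every missed lookup instead of null.
  def constantEmpty[T](): MiniBuffer[T] = new MiniBuffer[T] {
    override def += (value: T): MiniBuffer[T] =
      sys.error("+= is not valid for a constant empty buffer")
  }
}

// A relation that never returns null: a miss yields the shared sentinel, so
// callers test size == 0 rather than null, as fetchNext() does above.
class MiniRelation[K, V](table: Map[K, MiniBuffer[V]]) {
  private val empty = MiniBuffer.constantEmpty[V]()
  def get(key: K): MiniBuffer[V] = table.getOrElse(key, empty)
}

object NullObjectDemo {
  def main(args: Array[String]): Unit = {
    val matches = new MiniBuffer[String]
    matches += "row-1"

    val rel = new MiniRelation(Map("k1" -> matches))
    assert(rel.get("k1").size == 1)  // hit: one matching element
    assert(rel.get("zz").size == 0)  // miss: shared empty buffer, no null check
    println("null-object sketch OK: " + rel.get("k1")(0))
  }
}

The sketch models only the shared empty sentinel. The reused single-element buffer in UniqueKeyHashedRelation (not modeled here) additionally avoids a per-lookup allocation, at the cost of requiring that callers never retain the returned buffer.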