", actorSystem, master, serializer, 1200, conf,
- securityMgr, mapOutputTracker)
- store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store = makeBlockManager(12000)
+ store.putSingle(rdd(0, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY)
// Access rdd_1_0 to ensure it's not least recently used.
assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store")
// According to the same-RDD rule, rdd_1_0 should be replaced here.
- store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY)
// rdd_1_0 should have been replaced, even though it's not the least recently used.
assert(store.memoryStore.contains(rdd(0, 0)), "rdd_0_0 was not in store")
assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store")
assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
}
+
+ test("reserve/release unroll memory") {
+ store = makeBlockManager(12000)
+ val memoryStore = store.memoryStore
+ assert(memoryStore.currentUnrollMemory === 0)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Reserve
+ memoryStore.reserveUnrollMemoryForThisThread(100)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 100)
+ memoryStore.reserveUnrollMemoryForThisThread(200)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 300)
+ memoryStore.reserveUnrollMemoryForThisThread(500)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 800)
+ memoryStore.reserveUnrollMemoryForThisThread(1000000)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted
+ // Release
+ memoryStore.releaseUnrollMemoryForThisThread(100)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 700)
+ memoryStore.releaseUnrollMemoryForThisThread(100)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 600)
+ // Reserve again
+ memoryStore.reserveUnrollMemoryForThisThread(4400)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 5000)
+ memoryStore.reserveUnrollMemoryForThisThread(20000)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted
+ // Release again
+ memoryStore.releaseUnrollMemoryForThisThread(1000)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 4000)
+ memoryStore.releaseUnrollMemoryForThisThread() // release all
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ }
+
+ /**
+ * Verify the result of MemoryStore#unrollSafely is as expected.
+ */
+ private def verifyUnroll(
+ expected: Iterator[Any],
+ result: Either[Array[Any], Iterator[Any]],
+ shouldBeArray: Boolean): Unit = {
+ val actual: Iterator[Any] = result match {
+ case Left(arr: Array[Any]) =>
+ assert(shouldBeArray, "expected iterator from unroll!")
+ arr.iterator
+ case Right(it: Iterator[Any]) =>
+ assert(!shouldBeArray, "expected array from unroll!")
+ it
+ case _ =>
+ fail("unroll returned neither an iterator nor an array...")
+ }
+ expected.zip(actual).foreach { case (e, a) =>
+ assert(e === a, "unroll did not return original values!")
+ }
+ }
+
+ test("safely unroll blocks") {
+ store = makeBlockManager(12000)
+ val smallList = List.fill(40)(new Array[Byte](100))
+ val bigList = List.fill(40)(new Array[Byte](1000))
+ val memoryStore = store.memoryStore
+ val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll with all the space in the world. This should succeed and return an array.
+ var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks)
+ verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll with not enough space. This should succeed after kicking out someBlock1.
+ store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY)
+ store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY)
+ unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks)
+ verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(droppedBlocks.size === 1)
+ assert(droppedBlocks.head._1 === TestBlockId("someBlock1"))
+ droppedBlocks.clear()
+
+ // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 =
+ // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator.
+ // In the meantime, however, we kicked out someBlock2 before giving up.
+ store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY)
+ unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks)
+ verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false)
+ assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ assert(droppedBlocks.size === 1)
+ assert(droppedBlocks.head._1 === TestBlockId("someBlock2"))
+ droppedBlocks.clear()
+ }
+
+ test("safely unroll blocks through putIterator") {
+ store = makeBlockManager(12000)
+ val memOnly = StorageLevel.MEMORY_ONLY
+ val memoryStore = store.memoryStore
+ val smallList = List.fill(40)(new Array[Byte](100))
+ val bigList = List.fill(40)(new Array[Byte](1000))
+ def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]]
+ def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]]
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll with plenty of space. This should succeed and cache both blocks.
+ val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true)
+ val result2 = memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true)
+ assert(memoryStore.contains("b1"))
+ assert(memoryStore.contains("b2"))
+ assert(result1.size > 0) // unroll was successful
+ assert(result2.size > 0)
+ assert(result1.data.isLeft) // unroll did not drop this block to disk
+ assert(result2.data.isLeft)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Re-put these two blocks so the block manager knows about them too. Otherwise, the block
+ // manager would not know how to drop them from memory later.
+ memoryStore.remove("b1")
+ memoryStore.remove("b2")
+ store.putIterator("b1", smallIterator, memOnly)
+ store.putIterator("b2", smallIterator, memOnly)
+
+ // Unroll with not enough space. This should succeed but kick out b1 in the process.
+ val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true)
+ assert(result3.size > 0)
+ assert(result3.data.isLeft)
+ assert(!memoryStore.contains("b1"))
+ assert(memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ memoryStore.remove("b3")
+ store.putIterator("b3", smallIterator, memOnly)
+
+ // Unroll huge block with not enough space. This should fail and kick out b2 in the process.
+ val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, returnValues = true)
+ assert(result4.size === 0) // unroll was unsuccessful
+ assert(result4.data.isLeft)
+ assert(!memoryStore.contains("b1"))
+ assert(!memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(!memoryStore.contains("b4"))
+ assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ }
+
+ /**
+ * This test is essentially identical to the preceding one, except that it uses MEMORY_AND_DISK.
+ */
+ test("safely unroll blocks through putIterator (disk)") {
+ store = makeBlockManager(12000)
+ val memAndDisk = StorageLevel.MEMORY_AND_DISK
+ val memoryStore = store.memoryStore
+ val diskStore = store.diskStore
+ val smallList = List.fill(40)(new Array[Byte](100))
+ val bigList = List.fill(40)(new Array[Byte](1000))
+ def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]]
+ def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]]
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ store.putIterator("b1", smallIterator, memAndDisk)
+ store.putIterator("b2", smallIterator, memAndDisk)
+
+ // Unroll with not enough space. This should succeed but kick out b1 in the process.
+ // Memory store should contain b2 and b3, while disk store should contain only b1
+ val result3 = memoryStore.putIterator("b3", smallIterator, memAndDisk, returnValues = true)
+ assert(result3.size > 0)
+ assert(!memoryStore.contains("b1"))
+ assert(memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(diskStore.contains("b1"))
+ assert(!diskStore.contains("b2"))
+ assert(!diskStore.contains("b3"))
+ memoryStore.remove("b3")
+ store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll huge block with not enough space. This should fail and drop the new block to disk
+ // directly in addition to kicking out b2 in the process. Memory store should contain only
+ // b3, while disk store should contain b1, b2 and b4.
+ val result4 = memoryStore.putIterator("b4", bigIterator, memAndDisk, returnValues = true)
+ assert(result4.size > 0)
+ assert(result4.data.isRight) // unroll returned bytes from disk
+ assert(!memoryStore.contains("b1"))
+ assert(!memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(!memoryStore.contains("b4"))
+ assert(diskStore.contains("b1"))
+ assert(diskStore.contains("b2"))
+ assert(!diskStore.contains("b3"))
+ assert(diskStore.contains("b4"))
+ assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ }
+
+ test("multiple unrolls by the same thread") {
+ store = makeBlockManager(12000)
+ val memOnly = StorageLevel.MEMORY_ONLY
+ val memoryStore = store.memoryStore
+ val smallList = List.fill(40)(new Array[Byte](100))
+ def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]]
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // All unroll memory used is released because unrollSafely returned an array
+ memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll memory is not released because unrollSafely returned an iterator
+ // that still depends on the underlying vector used in the process
+ memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread
+ assert(unrollMemoryAfterB3 > 0)
+
+ // The unroll memory owned by this thread builds on top of its value after the previous unrolls
+ memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread
+ assert(unrollMemoryAfterB4 > unrollMemoryAfterB3)
+
+ // ... but only to a certain extent (until we run out of free space to grant new unroll memory)
+ memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread
+ memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread
+ memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread
+ assert(unrollMemoryAfterB5 === unrollMemoryAfterB4)
+ assert(unrollMemoryAfterB6 === unrollMemoryAfterB4)
+ assert(unrollMemoryAfterB7 === unrollMemoryAfterB4)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
deleted file mode 100644
index 93f0c6a8e6408..0000000000000
--- a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import scala.util.Random
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-
-import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass
-import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap}
-
-class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll {
- val NORMAL_ERROR = 0.20
- val HIGH_ERROR = 0.30
-
- test("fixed size insertions") {
- testWith[Int, Long](10000, i => (i, i.toLong))
- testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong)))
- testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass()))
- }
-
- test("variable size insertions") {
- val rand = new Random(123456789)
- def randString(minLen: Int, maxLen: Int): String = {
- "a" * (rand.nextInt(maxLen - minLen) + minLen)
- }
- testWith[Int, String](10000, i => (i, randString(0, 10)))
- testWith[Int, String](10000, i => (i, randString(0, 100)))
- testWith[Int, String](10000, i => (i, randString(90, 100)))
- }
-
- test("updates") {
- val rand = new Random(123456789)
- def randString(minLen: Int, maxLen: Int): String = {
- "a" * (rand.nextInt(maxLen - minLen) + minLen)
- }
- testWith[String, Int](10000, i => (randString(0, 10000), i))
- }
-
- def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) {
- val map = new SizeTrackingAppendOnlyMap[K, V]()
- for (i <- 0 until numElements) {
- val (k, v) = makeElement(i)
- map(k) = v
- expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
- }
- }
-
- def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) {
- val betterEstimatedSize = SizeEstimator.estimate(obj)
- assert(betterEstimatedSize * (1 - error) < estimatedSize,
- s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize")
- assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize,
- s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize")
- }
-}
-
-object SizeTrackingAppendOnlyMapSuite {
- // Speed test, for reproducibility of results.
- // These could be highly non-deterministic in general, however.
- // Results:
- // AppendOnlyMap: 31 ms
- // SizeTracker: 54 ms
- // SizeEstimator: 1500 ms
- def main(args: Array[String]) {
- val numElements = 100000
-
- val baseTimes = for (i <- 0 until 10) yield time {
- val map = new AppendOnlyMap[Int, LargeDummyClass]()
- for (i <- 0 until numElements) {
- map(i) = new LargeDummyClass()
- }
- }
-
- val sampledTimes = for (i <- 0 until 10) yield time {
- val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]()
- for (i <- 0 until numElements) {
- map(i) = new LargeDummyClass()
- map.estimateSize()
- }
- }
-
- val unsampledTimes = for (i <- 0 until 3) yield time {
- val map = new AppendOnlyMap[Int, LargeDummyClass]()
- for (i <- 0 until numElements) {
- map(i) = new LargeDummyClass()
- SizeEstimator.estimate(map)
- }
- }
-
- println("Base: " + baseTimes)
- println("SizeTracker (sampled): " + sampledTimes)
- println("SizeEstimator (unsampled): " + unsampledTimes)
- }
-
- def time(f: => Unit): Long = {
- val start = System.currentTimeMillis()
- f
- System.currentTimeMillis() - start
- }
-
- private class LargeDummyClass {
- val arr = new Array[Int](100)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala
new file mode 100644
index 0000000000000..6c956d93dc80d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+class CompactBufferSuite extends FunSuite {
+ test("empty buffer") {
+ val b = new CompactBuffer[Int]
+ assert(b.size === 0)
+ assert(b.iterator.toList === Nil)
+ assert(b.size === 0)
+ assert(b.iterator.toList === Nil)
+ intercept[IndexOutOfBoundsException] { b(0) }
+ intercept[IndexOutOfBoundsException] { b(1) }
+ intercept[IndexOutOfBoundsException] { b(2) }
+ intercept[IndexOutOfBoundsException] { b(-1) }
+ }
+
+ test("basic inserts") {
+ val b = new CompactBuffer[Int]
+ assert(b.size === 0)
+ assert(b.iterator.toList === Nil)
+ for (i <- 0 until 1000) {
+ b += i
+ assert(b.size === i + 1)
+ assert(b(i) === i)
+ }
+ assert(b.iterator.toList === (0 until 1000).toList)
+ assert(b.iterator.toList === (0 until 1000).toList)
+ assert(b.size === 1000)
+ }
+
+ test("adding sequences") {
+ val b = new CompactBuffer[Int]
+ assert(b.size === 0)
+ assert(b.iterator.toList === Nil)
+
+ // Add some simple lists and iterators
+ b ++= List(0)
+ assert(b.size === 1)
+ assert(b.iterator.toList === List(0))
+ b ++= Iterator(1)
+ assert(b.size === 2)
+ assert(b.iterator.toList === List(0, 1))
+ b ++= List(2)
+ assert(b.size === 3)
+ assert(b.iterator.toList === List(0, 1, 2))
+ b ++= Iterator(3, 4, 5, 6, 7, 8, 9)
+ assert(b.size === 10)
+ assert(b.iterator.toList === (0 until 10).toList)
+
+ // Add CompactBuffers
+ val b2 = new CompactBuffer[Int]
+ b2 ++= 0 until 10
+ b ++= b2
+ assert(b.iterator.toList === (1 to 2).flatMap(i => 0 until 10).toList)
+ b ++= b2
+ assert(b.iterator.toList === (1 to 3).flatMap(i => 0 until 10).toList)
+ b ++= b2
+ assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList)
+
+ // Add some small CompactBuffers as well
+ val b3 = new CompactBuffer[Int]
+ b ++= b3
+ assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList)
+ b3 += 0
+ b ++= b3
+ assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0))
+ b3 += 1
+ b ++= b3
+ assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0, 0, 1))
+ b3 += 2
+ b ++= b3
+ assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0, 0, 1, 0, 1, 2))
+ }
+
+ test("adding the same buffer to itself") {
+ val b = new CompactBuffer[Int]
+ assert(b.size === 0)
+ assert(b.iterator.toList === Nil)
+ b += 1
+ assert(b.toList === List(1))
+ for (j <- 1 until 8) {
+ b ++= b
+ assert(b.size === (1 << j))
+ assert(b.iterator.toList === (1 to (1 << j)).map(i => 1).toList)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 428822949c085..0b7ad184a46d2 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -63,12 +63,13 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
mergeValue, mergeCombiners)
- map.insert(1, 10)
- map.insert(2, 20)
- map.insert(3, 30)
- map.insert(1, 100)
- map.insert(2, 200)
- map.insert(1, 1000)
+ map.insertAll(Seq(
+ (1, 10),
+ (2, 20),
+ (3, 30),
+ (1, 100),
+ (2, 200),
+ (1, 1000)))
val it = map.iterator
assert(it.hasNext)
val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
@@ -282,7 +283,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
assert(w1.hashCode === w2.hashCode)
}
- (1 to 100000).map(_.toString).foreach { i => map.insert(i, i) }
+ map.insertAll((1 to 100000).iterator.map(_.toString).map(i => (i, i)))
collisionPairs.foreach { case (w1, w2) =>
map.insert(w1, w2)
map.insert(w2, w1)
@@ -355,7 +356,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
createCombiner, mergeValue, mergeCombiners)
- (1 to 100000).foreach { i => map.insert(i, i) }
+ map.insertAll((1 to 100000).iterator.map(i => (i, i)))
map.insert(null.asInstanceOf[Int], 1)
map.insert(1, null.asInstanceOf[Int])
map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int])
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
new file mode 100644
index 0000000000000..1f33967249654
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.util.SizeEstimator
+
+class SizeTrackerSuite extends FunSuite {
+ val NORMAL_ERROR = 0.20
+ val HIGH_ERROR = 0.30
+
+ import SizeTrackerSuite._
+
+ test("vector fixed size insertions") {
+ testVector[Long](10000, i => i.toLong)
+ testVector[(Long, Long)](10000, i => (i.toLong, i.toLong))
+ testVector[LargeDummyClass](10000, i => new LargeDummyClass)
+ }
+
+ test("vector variable size insertions") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testVector[String](10000, i => randString(0, 10))
+ testVector[String](10000, i => randString(0, 100))
+ testVector[String](10000, i => randString(90, 100))
+ }
+
+ test("map fixed size insertions") {
+ testMap[Int, Long](10000, i => (i, i.toLong))
+ testMap[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong)))
+ testMap[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass))
+ }
+
+ test("map variable size insertions") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testMap[Int, String](10000, i => (i, randString(0, 10)))
+ testMap[Int, String](10000, i => (i, randString(0, 100)))
+ testMap[Int, String](10000, i => (i, randString(90, 100)))
+ }
+
+ test("map updates") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testMap[String, Int](10000, i => (randString(0, 10000), i))
+ }
+
+ def testVector[T: ClassTag](numElements: Int, makeElement: Int => T) {
+ val vector = new SizeTrackingVector[T]
+ for (i <- 0 until numElements) {
+ val item = makeElement(i)
+ vector += item
+ expectWithinError(vector, vector.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
+ }
+ }
+
+ def testMap[K, V](numElements: Int, makeElement: (Int) => (K, V)) {
+ val map = new SizeTrackingAppendOnlyMap[K, V]
+ for (i <- 0 until numElements) {
+ val (k, v) = makeElement(i)
+ map(k) = v
+ expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
+ }
+ }
+
+ def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) {
+ val betterEstimatedSize = SizeEstimator.estimate(obj)
+ assert(betterEstimatedSize * (1 - error) < estimatedSize,
+ s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize")
+ assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize,
+ s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize")
+ }
+}
+
+private object SizeTrackerSuite {
+
+ /**
+ * Run speed tests for size tracking collections.
+ */
+ def main(args: Array[String]): Unit = {
+ if (args.size < 1) {
+ println("Usage: SizeTrackerSuite [num elements]")
+ System.exit(1)
+ }
+ val numElements = args(0).toInt
+ vectorSpeedTest(numElements)
+ mapSpeedTest(numElements)
+ }
+
+ /**
+ * Speed test for SizeTrackingVector.
+ *
+ * Results for 100000 elements (possibly non-deterministic):
+ * PrimitiveVector 15 ms
+ * SizeTracker 51 ms
+ * SizeEstimator 2000 ms
+ */
+ def vectorSpeedTest(numElements: Int): Unit = {
+ val baseTimes = for (i <- 0 until 10) yield time {
+ val vector = new PrimitiveVector[LargeDummyClass]
+ for (i <- 0 until numElements) {
+ vector += new LargeDummyClass
+ }
+ }
+ val sampledTimes = for (i <- 0 until 10) yield time {
+ val vector = new SizeTrackingVector[LargeDummyClass]
+ for (i <- 0 until numElements) {
+ vector += new LargeDummyClass
+ vector.estimateSize()
+ }
+ }
+ val unsampledTimes = for (i <- 0 until 3) yield time {
+ val vector = new PrimitiveVector[LargeDummyClass]
+ for (i <- 0 until numElements) {
+ vector += new LargeDummyClass
+ SizeEstimator.estimate(vector)
+ }
+ }
+ printSpeedTestResult("SizeTrackingVector", baseTimes, sampledTimes, unsampledTimes)
+ }
+
+ /**
+ * Speed test for SizeTrackingAppendOnlyMap.
+ *
+ * Results for 100000 elements (possibly non-deterministic):
+ * AppendOnlyMap 30 ms
+ * SizeTracker 41 ms
+ * SizeEstimator 1666 ms
+ */
+ def mapSpeedTest(numElements: Int): Unit = {
+ val baseTimes = for (i <- 0 until 10) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass
+ }
+ }
+ val sampledTimes = for (i <- 0 until 10) yield time {
+ val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass
+ map.estimateSize()
+ }
+ }
+ val unsampledTimes = for (i <- 0 until 3) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass
+ SizeEstimator.estimate(map)
+ }
+ }
+ printSpeedTestResult("SizeTrackingAppendOnlyMap", baseTimes, sampledTimes, unsampledTimes)
+ }
+
+ def printSpeedTestResult(
+ testName: String,
+ baseTimes: Seq[Long],
+ sampledTimes: Seq[Long],
+ unsampledTimes: Seq[Long]): Unit = {
+ println(s"Average times for $testName (ms):")
+ println(" Base - " + averageTime(baseTimes))
+ println(" SizeTracker (sampled) - " + averageTime(sampledTimes))
+ println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes))
+ println()
+ }
+
+ def time(f: => Unit): Long = {
+ val start = System.currentTimeMillis()
+ f
+ System.currentTimeMillis() - start
+ }
+
+ def averageTime(v: Seq[Long]): Long = {
+ v.sum / v.size
+ }
+
+ private class LargeDummyClass {
+ val arr = new Array[Int](100)
+ }
+}
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 38830103d1e8d..33de24d1ae6d7 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -53,7 +53,7 @@ if [[ ! "$@" =~ --package-only ]]; then
-Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \
-Dmaven.javadoc.skip=true \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\
+ -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\
-Dtag=$GIT_TAG -DautoVersionSubmodules=true \
--batch-mode release:prepare
@@ -61,7 +61,7 @@ if [[ ! "$@" =~ --package-only ]]; then
-Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
-Dmaven.javadoc.skip=true \
- -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\
+ -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\
release:perform
cd ..
@@ -111,10 +111,10 @@ make_binary_release() {
spark-$RELEASE_VERSION-bin-$NAME.tgz.sha
}
-make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4"
-make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0"
+make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4"
+make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0"
make_binary_release "hadoop2" \
- "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0"
+ "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0"
# Copy data
echo "Copying release tarballs"
diff --git a/dev/run-tests b/dev/run-tests
index 51e4def0f835a..98ec969dc1b37 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -65,7 +65,7 @@ echo "========================================================================="
# (either resolution or compilation) prompts the user for input either q, r,
# etc to quit or retry. This echo is there to make it not block.
if [ -n "$_RUN_SQL_TESTS" ]; then
- echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \
+ echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive -Phive-thriftserver" sbt/sbt clean package \
assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
else
echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \
diff --git a/dev/scalastyle b/dev/scalastyle
index a02d06912f238..d9f2b91a3a091 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,7 +17,7 @@
# limitations under the License.
#
-echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
+echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt
# Check style with YARN alpha built too
echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
>> scalastyle.txt
diff --git a/docs/configuration.md b/docs/configuration.md
index cb0c65e2d2200..2e6c85cc2bcca 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -197,6 +197,15 @@ Apart from these, the following properties are also available, and may be useful
Spark's dependencies and user dependencies. It is currently an experimental feature.
+<tr>
+  <td><code>spark.python.worker.memory</code></td>
+  <td>512m</td>
+  <td>
+    Amount of memory to use per Python worker process during aggregation, in the same
+    format as JVM memory strings (e.g. <code>512m</code>, <code>2g</code>). If the memory
+    used during aggregation goes above this amount, it will spill the data to disk.
+  </td>
+</tr>
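As a quick, hypothetical sketch of the value format (in practice the property would more typically live in `spark-defaults.conf` or a PySpark job's own configuration), setting it through the Scala `SparkConf` API:

```scala
import org.apache.spark.SparkConf

// Raise the per-Python-worker aggregation limit to 1g before spilling to disk.
val conf = new SparkConf()
  .setAppName("python-aggregation-example") // hypothetical app name
  .set("spark.python.worker.memory", "1g")
```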
#### Shuffle Behavior
@@ -230,7 +239,7 @@ Apart from these, the following properties are also available, and may be useful
spark.shuffle.memoryFraction |
- 0.3 |
+ 0.2 |
Fraction of Java heap to use for aggregation and cogroups during shuffles, if
spark.shuffle.spill is true. At any given time, the collective size of
@@ -371,13 +380,13 @@ Apart from these, the following properties are also available, and may be useful
|
spark.serializer.objectStreamReset |
- 10000 |
+ 100 |
When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches
objects to prevent writing redundant data, however that stops garbage collection of those
objects. By calling 'reset' you flush that info from the serializer, and allow old
objects to be collected. To turn off this periodic reset set it to a value <= 0.
- By default it will reset the serializer every 10,000 objects.
+ By default it will reset the serializer every 100 objects.
|
@@ -471,6 +480,15 @@ Apart from these, the following properties are also available, and may be useful
increase it if you configure your own old generation size.
+<tr>
+  <td><code>spark.storage.unrollFraction</code></td>
+  <td>0.2</td>
+  <td>
+    Fraction of <code>spark.storage.memoryFraction</code> to use for unrolling blocks in memory.
+    This is dynamically allocated by dropping existing blocks when there is not enough free
+    storage space to unroll the new block in its entirety.
+  </td>
+</tr>
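As a rough worked example of the budget this fraction implies, assuming the 0.6 default for `spark.storage.memoryFraction` and a hypothetical 4 GB executor heap:

```scala
// Back-of-the-envelope unroll budget under assumed defaults.
val executorHeapBytes     = 4L * 1024 * 1024 * 1024 // hypothetical 4 GB executor heap
val storageMemoryFraction = 0.6                     // spark.storage.memoryFraction (default)
val unrollFraction        = 0.2                     // spark.storage.unrollFraction (default)
val maxUnrollBytes = (executorHeapBytes * storageMemoryFraction * unrollFraction).toLong
// ~0.48 GB of storage memory may be claimed for unrolling before existing blocks are dropped
```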
spark.tachyonStore.baseDir |
System.getProperty("java.io.tmpdir") |
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 38728534a46e0..156e0aebdebe6 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -136,7 +136,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.createSchemaRDD
// Define the schema using a case class.
-// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit,
+// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit,
// you can use custom classes that implement the Product interface.
case class Person(name: String, age: Int)
@@ -548,7 +548,6 @@ results = hiveContext.hql("FROM src SELECT key, value").collect()
-
# Writing Language-Integrated Relational Queries
**Language-Integrated queries are currently only supported in Scala.**
@@ -573,4 +572,200 @@ prefixed with a tick (`'`). Implicit conversions turn these symbols into expres
evaluated by the SQL execution engine. A full list of the functions supported can be found in the
[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD).
-
\ No newline at end of file
+
+
+## Running the Thrift JDBC server
+
+The Thrift JDBC server implemented here corresponds to the [`HiveServer2`]
+(https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test
+the JDBC server with the beeline script that comes with either Spark or Hive 0.12. In order to use
+Hive you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use
+`-Phive-thriftserver` for Maven).
+
+To start the JDBC server, run the following in the Spark directory:
+
+ ./sbin/start-thriftserver.sh
+
+The default port the server listens on is 10000. To listen on a custom host and port, please set
+the `HIVE_SERVER2_THRIFT_PORT` and `HIVE_SERVER2_THRIFT_BIND_HOST` environment variables. You may
+run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. Now you can
+use beeline to test the Thrift JDBC server:
+
+ ./bin/beeline
+
+Connect to the JDBC server in beeline with:
+
+ beeline> !connect jdbc:hive2://localhost:10000
+
+Beeline will ask you for a username and password. In non-secure mode, simply enter the username on
+your machine and a blank password. For secure mode, please follow the instructions given in the
+[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients).
+
+Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
+
+You may also use the beeline script that comes with Hive.
+
+### Migration Guide for Shark Users
+
+#### Reducer number
+
+In Shark, the default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark
+SQL deprecates this property in favor of a new property, `spark.sql.shuffle.partitions`, whose default
+value is 200. Users may customize this property via `SET`:
+
+```
+SET spark.sql.shuffle.partitions=10;
+SELECT page, count(*) c FROM logs_last_month_cached
+GROUP BY page ORDER BY c DESC LIMIT 10;
+```
+
+You may also put this property in `hive-site.xml` to override the default value.
+
+For now, the `mapred.reduce.tasks` property is still recognized, and is converted to
+`spark.sql.shuffle.partitions` automatically.
+
+#### Caching
+
+The `shark.cache` table property no longer exists, and tables whose names end with `_cached` are no
+longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to
+let users control table caching explicitly:
+
+```
+CACHE TABLE logs_last_month;
+UNCACHE TABLE logs_last_month;
+```
+
+**NOTE** `CACHE TABLE tbl` is lazy: it only marks table `tbl` as "needs to be cached if necessary",
+but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be
+cached, you may simply count the table immediately after executing `CACHE TABLE`:
+
+```
+CACHE TABLE logs_last_month;
+SELECT COUNT(1) FROM logs_last_month;
+```
+
+Several caching-related features are not supported yet:
+
+* User defined partition level cache eviction policy
+* RDD reloading
+* In-memory cache write through policy
+
+### Compatibility with Apache Hive
+
+#### Deploying in Existing Hive Warehouses
+
+The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive
+installations. You do not need to modify your existing Hive Metastore or change the data placement
+or partitioning of your tables.
+
+#### Supported Hive Features
+
+Spark SQL supports the vast majority of Hive features, such as:
+
+* Hive query statements, including:
+ * `SELECT`
+ * `GROUP BY`
+ * `ORDER BY`
+ * `CLUSTER BY`
+ * `SORT BY`
+* All Hive operators, including:
+ * Relational operators (`=`, `⇔`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc)
+ * Arithmetic operators (`+`, `-`, `*`, `/`, `%`, etc)
+ * Logical operators (`AND`, `&&`, `OR`, `||`, etc)
+ * Complex type constructors
+ * Mathematical functions (`sign`, `ln`, `cos`, etc)
+ * String functions (`instr`, `length`, `printf`, etc)
+* User defined functions (UDF)
+* User defined aggregation functions (UDAF)
+* User defined serialization formats (SerDes)
+* Joins
+ * `JOIN`
+ * `{LEFT|RIGHT|FULL} OUTER JOIN`
+ * `LEFT SEMI JOIN`
+ * `CROSS JOIN`
+* Unions
+* Subqueries
+ * `SELECT col FROM ( SELECT a + b AS col from t1) t2`
+* Sampling
+* Explain
+* Partitioned tables
+* All Hive DDL Functions, including:
+ * `CREATE TABLE`
+ * `CREATE TABLE AS SELECT`
+ * `ALTER TABLE`
+* Most Hive Data types, including:
+ * `TINYINT`
+ * `SMALLINT`
+ * `INT`
+ * `BIGINT`
+ * `BOOLEAN`
+ * `FLOAT`
+ * `DOUBLE`
+ * `STRING`
+ * `BINARY`
+ * `TIMESTAMP`
+ * `ARRAY<>`
+ * `MAP<>`
+ * `STRUCT<>`
+
+#### Unsupported Hive Functionality
+
+Below is a list of Hive features that we don't support yet. Most of these features are rarely used
+in Hive deployments.
+
+**Major Hive Features**
+
+* Tables with buckets: bucketing is hash partitioning within a Hive table partition. Spark SQL
+ doesn't support buckets yet.
+
+**Esoteric Hive Features**
+
+* Tables with partitions using different input formats: In Spark SQL, all table partitions need to
+ have the same input format.
+* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions
+ (e.g. condition "`key < 10`"), Spark SQL will output wrong results for the `NULL` tuple.
+* `UNIONTYPE`
+* Unique join
+* Single query multi insert
+* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at
+ the moment.
+
+**Hive Input/Output Formats**
+
+* File format for CLI: For results shown back to the CLI, Spark SQL only supports TextOutputFormat.
+* Hadoop archive
+
+**Hive Optimizations**
+
+A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are
+not necessary due to Spark SQL's in-memory computational model. Others are slotted for future
+releases of Spark SQL.
+
+* Block level bitmap indexes and virtual columns (used to build indexes)
+* Automatically convert a join to map join: For joining a large table with multiple small tables,
+ Hive automatically converts the join into a map join. We are adding this auto conversion in the
+ next release.
+* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you
+ need to control the degree of parallelism post-shuffle using "SET
+ spark.sql.shuffle.partitions=[num_tasks];". We are going to add auto-setting of parallelism in the
+ next release.
+* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still
+ launches tasks to compute the result.
+* Skew data flag: Spark SQL does not follow the skew data flags in Hive.
+* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint.
+* Merge multiple small files for query results: if the result output contains multiple small files,
+ Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS
+ metadata. Spark SQL does not support that.
+
+## Running the Spark SQL CLI
+
+The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute
+queries input from the command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server.
+
+To start the Spark SQL CLI, run the following in the Spark directory:
+
+ ./bin/spark-sql
+
+Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
+You may run `./bin/spark-sql --help` for a complete list of all available
+options.
diff --git a/examples/pom.xml b/examples/pom.xml
index bd1c387c2eb91..c4ed0f5a6a02b 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-examples_2.10
- examples
+ examples
jar
Spark Project Examples
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 61a6aff543aed..874b8a7959bb6 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-flume_2.10
- streaming-flume
+ streaming-flume
jar
Spark Project External Flume
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 4762c50685a93..25a5c0a4d7d77 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-kafka_2.10
- streaming-kafka
+ streaming-kafka
jar
Spark Project External Kafka
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 32c530e600ce0..f31ed655f6779 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-mqtt_2.10
- streaming-mqtt
+ streaming-mqtt
jar
Spark Project External MQTT
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 637adb0f00da0..56bb24c2a072e 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-twitter_2.10
- streaming-twitter
+ streaming-twitter
jar
Spark Project External Twitter
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index e4d758a04a4cd..54b0242c54e78 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-zeromq_2.10
- streaming-zeromq
+ streaming-zeromq
jar
Spark Project External ZeroMQ
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 7e3bcf29dcfbc..6dd52fc618b1e 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-graphx_2.10
- graphx
+ graphx
jar
Spark Project GraphX
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 92b07e2357db1..f27cf520dc9fa 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-mllib_2.10
- mllib
+ mllib
jar
Spark Project ML Library
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index c44173793b39a..954621ee8b933 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable {
}
}
+ private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
+ require(bytes.length - offset == 8, "Wrong size byte array for Double")
+ val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
+ bb.order(ByteOrder.nativeOrder())
+ bb.getDouble
+ }
+
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
@@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable {
Vectors.sparse(size, indices, values)
}
+ /**
+ * Returns an 8-byte array for the input Double.
+ *
+ * Note: we currently do not use a magic byte for double for storage efficiency.
+ * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
+ * The corresponding deserializer, deserializeDouble, needs to be modified as well if the
+ * serialization scheme changes.
+ */
+ private[python] def serializeDouble(double: Double): Array[Byte] = {
+ val bytes = new Array[Byte](8)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.putDouble(double)
+ bytes
+ }
+
private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len)
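For illustration only (not part of the patch), the same 8-byte, native-order double encoding used by `serializeDouble`/`deserializeDouble` above can be reproduced with a standalone round-trip:

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Illustrative round-trip of the 8-byte, native-order double encoding described above.
def writeDouble(d: Double): Array[Byte] = {
  val bytes = new Array[Byte](8)
  ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).putDouble(d)
  bytes
}

def readDouble(bytes: Array[Byte]): Double = {
  require(bytes.length == 8, "Wrong size byte array for Double")
  ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).getDouble
}

assert(readDouble(writeDouble(math.Pi)) == math.Pi) // round-trips exactly
```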
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index b6e0c4a80e27b..6c7be0a4f1dcb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -54,7 +54,13 @@ class NaiveBayesModel private[mllib] (
}
}
- override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
+ override def predict(testData: RDD[Vector]): RDD[Double] = {
+ val bcModel = testData.context.broadcast(this)
+ testData.mapPartitions { iter =>
+ val model = bcModel.value
+ iter.map(model.predict)
+ }
+ }
override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index de22fbb6ffc10..db425d866bbad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -165,18 +165,21 @@ class KMeans private (
val activeCenters = activeRuns.map(r => centers(r)).toArray
val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
+ val bcActiveCenters = sc.broadcast(activeCenters)
+
// Find the sum and count of points mapping to each center
val totalContribs = data.mapPartitions { points =>
- val runs = activeCenters.length
- val k = activeCenters(0).length
- val dims = activeCenters(0)(0).vector.length
+ val thisActiveCenters = bcActiveCenters.value
+ val runs = thisActiveCenters.length
+ val k = thisActiveCenters(0).length
+ val dims = thisActiveCenters(0)(0).vector.length
val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
val counts = Array.fill(runs, k)(0L)
points.foreach { point =>
(0 until runs).foreach { i =>
- val (bestCenter, cost) = KMeans.findClosest(activeCenters(i), point)
+ val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
costAccums(i) += cost
sums(i)(bestCenter) += point.vector
counts(i)(bestCenter) += 1
@@ -264,16 +267,17 @@ class KMeans private (
// to their squared distance from that run's current centers
var step = 0
while (step < initializationSteps) {
+ val bcCenters = data.context.broadcast(centers)
val sumCosts = data.flatMap { point =>
(0 until runs).map { r =>
- (r, KMeans.pointCost(centers(r), point))
+ (r, KMeans.pointCost(bcCenters.value(r), point))
}
}.reduceByKey(_ + _).collectAsMap()
val chosen = data.mapPartitionsWithIndex { (index, points) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
points.flatMap { p =>
(0 until runs).filter { r =>
- rand.nextDouble() < 2.0 * KMeans.pointCost(centers(r), p) * k / sumCosts(r)
+ rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
}.map((_, p))
}
}.collect()
@@ -286,9 +290,10 @@ class KMeans private (
// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
+ val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
(0 until runs).map { r =>
- ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
+ ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
val finalCenters = (0 until runs).map { r =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index fba21aefaaacd..5823cb6e52e7f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -38,7 +38,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
val centersWithNorm = clusterCentersWithNorm
- points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
+ val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+ points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))._1)
}
/** Maps given points to their cluster indices. */
@@ -51,7 +52,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
*/
def computeCost(data: RDD[Vector]): Double = {
val centersWithNorm = clusterCentersWithNorm
- data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum()
+ val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+ data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))).sum()
}
private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 7030eeabe400a..9fd760bf78083 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -163,6 +163,7 @@ object GradientDescent extends Logging {
// Initialize weights as a column vector
var weights = Vectors.dense(initialWeights.toArray)
+ val n = weights.size
/**
* For the first iteration, the regVal will be initialized as sum of weight squares
@@ -172,12 +173,13 @@ object GradientDescent extends Logging {
weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
for (i <- 1 to numIterations) {
+ val bcWeights = data.context.broadcast(weights)
// Sample a subset (fraction miniBatchFraction) of the total data
// compute and sum up the subgradients on this subset (this is one map-reduce)
val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
- .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+ .aggregate((BDV.zeros[Double](n), 0.0))(
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
- val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
+ val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
(grad, loss + l)
},
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 7bbed9c8fdbef..179cd4a3f1625 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -195,13 +195,14 @@ object LBFGS extends Logging {
override def calculate(weights: BDV[Double]) = {
// Have a local copy to avoid the serialization of CostFun object which is not serializable.
- val localData = data
val localGradient = gradient
+ val n = weights.length
+ val bcWeights = data.context.broadcast(weights)
- val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
+ val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
val l = localGradient.compute(
- features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
+ features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
(grad, loss + l)
},
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
new file mode 100644
index 0000000000000..7ecb409c4a91a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
+
+/**
+ * :: Experimental ::
+ * Trait for random number generators that generate i.i.d. values from a distribution.
+ */
+@Experimental
+trait DistributionGenerator extends Pseudorandom with Serializable {
+
+ /**
+ * Returns an i.i.d. sample as a Double from an underlying distribution.
+ */
+ def nextValue(): Double
+
+ /**
+ * Returns a copy of the DistributionGenerator with a new instance of the rng object used in the
+ * class when applicable for non-locking concurrent usage.
+ */
+ def copy(): DistributionGenerator
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from U[0.0, 1.0]
+ */
+@Experimental
+class UniformGenerator extends DistributionGenerator {
+
+ // XORShiftRandom for better performance. Thread safety isn't necessary here.
+ private val random = new XORShiftRandom()
+
+ override def nextValue(): Double = {
+ random.nextDouble()
+ }
+
+ override def setSeed(seed: Long) = random.setSeed(seed)
+
+ override def copy(): UniformGenerator = new UniformGenerator()
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from the standard normal distribution.
+ */
+@Experimental
+class StandardNormalGenerator extends DistributionGenerator {
+
+ // XORShiftRandom for better performance. Thread safety isn't necessary here.
+ private val random = new XORShiftRandom()
+
+ override def nextValue(): Double = {
+ random.nextGaussian()
+ }
+
+ override def setSeed(seed: Long) = random.setSeed(seed)
+
+ override def copy(): StandardNormalGenerator = new StandardNormalGenerator()
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from the Poisson distribution with the given mean.
+ *
+ * @param mean mean for the Poisson distribution.
+ */
+@Experimental
+class PoissonGenerator(val mean: Double) extends DistributionGenerator {
+
+ private var rng = new Poisson(mean, new DRand)
+
+ override def nextValue(): Double = rng.nextDouble()
+
+ override def setSeed(seed: Long) {
+ rng = new Poisson(mean, new DRand(seed.toInt))
+ }
+
+ override def copy(): PoissonGenerator = new PoissonGenerator(mean)
+}
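
As a quick illustration of the trait above, here is a minimal sketch (not part of the patch; the object name and the literal seeds and sizes are made up) showing the generators used on their own, outside of any RDD:

import org.apache.spark.mllib.random.{PoissonGenerator, UniformGenerator}

object GeneratorSketch extends App {
  val uniform = new UniformGenerator()
  uniform.setSeed(42L)
  val u = Array.fill(5)(uniform.nextValue())   // i.i.d. samples ~ U[0.0, 1.0]

  // copy() hands back an independent generator, so concurrent callers can each
  // draw from their own instance without locking a shared RNG.
  val uniform2 = uniform.copy()
  uniform2.setSeed(42L)

  val poisson = new PoissonGenerator(5.0)
  poisson.setSeed(7L)
  val p = Array.fill(5)(poisson.nextValue())   // i.i.d. samples ~ Pois(5.0)

  println(u.mkString(", "))
  println(p.mkString(", "))
}
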
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
new file mode 100644
index 0000000000000..d7ee2d3f46846
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
@@ -0,0 +1,473 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+/**
+ * :: Experimental ::
+ * Generator methods for creating RDDs comprised of i.i.d samples from some distribution.
+ */
+@Experimental
+object RandomRDDGenerators {
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0].
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = {
+ val uniform = new UniformGenerator()
+ randomRDD(sc, uniform, size, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0].
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = {
+ uniformRDD(sc, size, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0].
+ * sc.defaultParallelism is used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformRDD(sc: SparkContext, size: Long): RDD[Double] = {
+ uniformRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the standard normal distribution.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = {
+ val normal = new StandardNormalGenerator()
+ randomRDD(sc, normal, size, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the standard normal distribution.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = {
+ normalRDD(sc, size, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the standard normal distribution.
+ * sc.defaultParallelism is used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalRDD(sc: SparkContext, size: Long): RDD[Double] = {
+ normalRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonRDD(sc: SparkContext,
+ mean: Double,
+ size: Long,
+ numPartitions: Int,
+ seed: Long): RDD[Double] = {
+ val poisson = new PoissonGenerator(mean)
+ randomRDD(sc, poisson, size, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonRDD(sc: SparkContext, mean: Double, size: Long, numPartitions: Int): RDD[Double] = {
+ poissonRDD(sc, mean, size, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean.
+ * sc.defaultParallelism is used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param size Size of the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonRDD(sc: SparkContext, mean: Double, size: Long): RDD[Double] = {
+ poissonRDD(sc, mean, size, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Double] comprised of i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ size: Long,
+ numPartitions: Int,
+ seed: Long): RDD[Double] = {
+ new RandomRDD(sc, size, numPartitions, generator, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ size: Long,
+ numPartitions: Int): RDD[Double] = {
+ randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator.
+ * sc.defaultParallelism is used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param size Size of the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ size: Long): RDD[Double] = {
+ randomRDD(sc, generator, size, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ // TODO Generate RDD[Vector] from multivariate distributions.
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * uniform distribution on [0.0, 1.0].
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformVectorRDD(sc: SparkContext,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int,
+ seed: Long): RDD[Vector] = {
+ val uniform = new UniformGenerator()
+ randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * uniform distribution on [0.0, 1.0].
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformVectorRDD(sc: SparkContext,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int): RDD[Vector] = {
+ uniformVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * uniform distribution on [0.0, 1.0].
+ * sc.defaultParallelism is used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = {
+ uniformVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * standard normal distribution.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalVectorRDD(sc: SparkContext,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int,
+ seed: Long): RDD[Vector] = {
+ val normal = new StandardNormalGenerator()
+ randomVectorRDD(sc, normal, numRows, numCols, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * standard normal distribution.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalVectorRDD(sc: SparkContext,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int): RDD[Vector] = {
+ normalVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * standard normal distribution.
+ * sc.defaultParallelism is used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = {
+ normalVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * Poisson distribution with the input mean.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonVectorRDD(sc: SparkContext,
+ mean: Double,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int,
+ seed: Long): RDD[Vector] = {
+ val poisson = new PoissonGenerator(mean)
+ randomVectorRDD(sc, poisson, numRows, numCols, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * Poisson distribution with the input mean.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonVectorRDD(sc: SparkContext,
+ mean: Double,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int): RDD[Vector] = {
+ poissonVectorRDD(sc, mean, numRows, numCols, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * Poisson distribution with the input mean.
+ * sc.defaultParallelism is used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonVectorRDD(sc: SparkContext,
+ mean: Double,
+ numRows: Long,
+ numCols: Int): RDD[Vector] = {
+ poissonVectorRDD(sc, mean, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the
+ * input DistributionGenerator.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Vector] with vectors containing i.i.d samples produced by generator.
+ */
+ @Experimental
+ def randomVectorRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int,
+ seed: Long): RDD[Vector] = {
+ new RandomVectorRDD(sc, numRows, numCols, numPartitions, generator, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the
+ * input DistributionGenerator.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Vector] with vectors containing i.i.d samples produced by generator.
+ */
+ @Experimental
+ def randomVectorRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int): RDD[Vector] = {
+ randomVectorRDD(sc, generator, numRows, numCols, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the
+ * input DistributionGenerator.
+ * sc.defaultParallelism is used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @return RDD[Vector] with vectors containing i.i.d samples produced by generator.
+ */
+ @Experimental
+ def randomVectorRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ numRows: Long,
+ numCols: Int): RDD[Vector] = {
+ randomVectorRDD(sc, generator, numRows, numCols,
+ sc.defaultParallelism, Utils.random.nextLong)
+ }
+}
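
A short usage sketch for the factory methods above; it assumes an existing SparkContext named sc, and the sizes and "roughly" comments are illustrative rather than taken from the patch:

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._   // implicit mean()/variance() on RDD[Double]
import org.apache.spark.mllib.random.RandomRDDGenerators._

def generatorSketch(sc: SparkContext): Unit = {
  val u = uniformRDD(sc, 1000000L, 10, 0L)          // RDD[Double], i.i.d. ~ U[0.0, 1.0]
  val n = normalVectorRDD(sc, 10000L, 50, 10, 0L)   // RDD[Vector], 10000 rows x 50 cols
  val p = poissonRDD(sc, 4.0, 1000000L)             // numPartitions = sc.defaultParallelism

  println(u.mean())       // should be roughly 0.5
  println(p.variance())   // should be roughly 4.0 (mean == variance for a Poisson)
}
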
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
new file mode 100644
index 0000000000000..f13282d07ff92
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import scala.util.Random
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector}
+import org.apache.spark.mllib.random.DistributionGenerator
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+private[mllib] class RandomRDDPartition(override val index: Int,
+ val size: Int,
+ val generator: DistributionGenerator,
+ val seed: Long) extends Partition {
+
+ require(size >= 0, "Non-negative partition size required.")
+}
+
+// These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue
+private[mllib] class RandomRDD(@transient sc: SparkContext,
+ size: Long,
+ numPartitions: Int,
+ @transient rng: DistributionGenerator,
+ @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, Nil) {
+
+ require(size > 0, "Positive RDD size required.")
+ require(numPartitions > 0, "Positive number of partitions required")
+ require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue,
+ "Partition size cannot exceed Int.MaxValue")
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[Double] = {
+ val split = splitIn.asInstanceOf[RandomRDDPartition]
+ RandomRDD.getPointIterator(split)
+ }
+
+ override def getPartitions: Array[Partition] = {
+ RandomRDD.getPartitions(size, numPartitions, rng, seed)
+ }
+}
+
+private[mllib] class RandomVectorRDD(@transient sc: SparkContext,
+ size: Long,
+ vectorSize: Int,
+ numPartitions: Int,
+ @transient rng: DistributionGenerator,
+ @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) {
+
+ require(size > 0, "Positive RDD size required.")
+ require(numPartitions > 0, "Positive number of partitions required")
+ require(vectorSize > 0, "Positive vector size required.")
+ require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue,
+ "Partition size cannot exceed Int.MaxValue")
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[Vector] = {
+ val split = splitIn.asInstanceOf[RandomRDDPartition]
+ RandomRDD.getVectorIterator(split, vectorSize)
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ RandomRDD.getPartitions(size, numPartitions, rng, seed)
+ }
+}
+
+private[mllib] object RandomRDD {
+
+ def getPartitions(size: Long,
+ numPartitions: Int,
+ rng: DistributionGenerator,
+ seed: Long): Array[Partition] = {
+
+ val partitions = new Array[RandomRDDPartition](numPartitions)
+ var i = 0
+ var start: Long = 0
+ var end: Long = 0
+ val random = new Random(seed)
+ while (i < numPartitions) {
+ end = ((i + 1) * size) / numPartitions
+ partitions(i) = new RandomRDDPartition(i, (end - start).toInt, rng, random.nextLong())
+ start = end
+ i += 1
+ }
+ partitions.asInstanceOf[Array[Partition]]
+ }
+
+ // The RNG has to be reset every time the iterator is requested to guarantee that the same data
+ // is produced every time the content of the RDD is examined.
+ def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = {
+ val generator = partition.generator.copy()
+ generator.setSeed(partition.seed)
+ Array.fill(partition.size)(generator.nextValue()).toIterator
+ }
+
+ // The RNG has to be reset every time the iterator is requested to guarantee that the same data
+ // is produced every time the content of the RDD is examined.
+ def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): Iterator[Vector] = {
+ val generator = partition.generator.copy()
+ generator.setSeed(partition.seed)
+ Array.fill(partition.size)(new DenseVector(
+ (0 until vectorSize).map { _ => generator.nextValue() }.toArray)).toIterator
+ }
+}
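
The slicing arithmetic in getPartitions above is worth seeing in isolation. A standalone sketch (names and inputs are illustrative) showing that partition sizes always differ by at most one element and sum to the requested size:

object SliceSketch extends App {
  // Mirrors RandomRDD.getPartitions: partition i covers [i * size / n, (i + 1) * size / n).
  def sizes(size: Long, numPartitions: Int): Seq[Long] = {
    var start = 0L
    (0 until numPartitions).map { i =>
      val end = ((i + 1) * size) / numPartitions
      val len = end - start
      start = end
      len
    }
  }

  println(sizes(12345L, 4))   // Vector(3086, 3086, 3086, 3087), sums to 12345
  println(sizes(10L, 3))      // Vector(3, 3, 4)
}
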
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index fe41863bce985..54854252d7477 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -56,9 +56,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
// and intercept is needed.
val localWeights = weights
+ val bcWeights = testData.context.broadcast(localWeights)
val localIntercept = intercept
-
- testData.map(v => predictPoint(v, localWeights, localIntercept))
+ testData.mapPartitions { iter =>
+ val w = bcWeights.value
+ iter.map(v => predictPoint(v, w, localIntercept))
+ }
}
/**
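
For context, a minimal sketch of the broadcast-then-mapPartitions pattern the predict change above relies on; the method name and parameters here are illustrative, not the patch's API. Without the broadcast, the weights vector would be serialized into every task closure; with it, each executor fetches the vector once and each partition reads it once.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def broadcastPredict(
    sc: SparkContext,
    data: RDD[Vector],
    weights: Vector,
    intercept: Double,
    predictPoint: (Vector, Vector, Double) => Double): RDD[Double] = {
  val bcWeights = sc.broadcast(weights)   // shipped to each executor once
  data.mapPartitions { iter =>
    val w = bcWeights.value               // resolved once per partition
    iter.map(v => predictPoint(v, w, intercept))
  }
}
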
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala
index 88de2c82479b7..1f7de630e778c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala
@@ -122,6 +122,10 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = {
val partitioner = new HashPartitioner(input.partitions.size)
val cogrouped = new CoGroupedRDD[Long](ranks, partitioner)
- cogrouped.map { case (_, values: Seq[Seq[Double]]) => new DenseVector(values.flatten.toArray) }
+ cogrouped.map {
+ case (_, values: Array[Iterable[_]]) =>
+ val doubles = values.asInstanceOf[Array[Iterable[Double]]]
+ new DenseVector(doubles.flatten.toArray)
+ }
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index faa675b59cd50..862221d48798a 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -92,8 +92,6 @@ public void runLRUsingStaticMethods() {
testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
- System.out.println(numAccurate);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
-
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index 642843f90204c..d94cfa2fcec81 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite {
assert(q.features === p.features)
}
}
+
+ test("double serialization") {
+ for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) {
+ val bytes = py.serializeDouble(x)
+ val deser = py.deserializeDouble(bytes)
+ assert(x === deser)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 44b757b6a1fb7..3f6ff859374c7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object LogisticRegressionSuite {
@@ -126,3 +126,19 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LogisticRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 516895d04222d..06cdd04f5fdae 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object NaiveBayesSuite {
@@ -96,3 +96,21 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 10
+ val n = 200000
+ val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map { i =>
+ LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble())))
+ }
+ }
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = NaiveBayes.train(examples)
+ val predictions = model.predict(examples.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 886c71dde3af7..65e5df58db4c7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -17,17 +17,16 @@
package org.apache.spark.mllib.classification
-import scala.util.Random
import scala.collection.JavaConversions._
-
-import org.scalatest.FunSuite
+import scala.util.Random
import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
import org.apache.spark.SparkException
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object SVMSuite {
@@ -193,3 +192,19 @@ class SVMSuite extends FunSuite with LocalSparkContext {
new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}
+
+class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = SVMWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 76a3bdf9b11c8..34bc4537a7b3a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -17,14 +17,16 @@
package org.apache.spark.mllib.clustering
+import scala.util.Random
+
import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class KMeansSuite extends FunSuite with LocalSparkContext {
- import KMeans.{RANDOM, K_MEANS_PARALLEL}
+ import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
test("single cluster") {
val data = sc.parallelize(Array(
@@ -38,26 +40,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// No matter how many runs or iterations we use, we should get one cluster,
// centered at the mean of the points
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
model = KMeans.train(
- data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
}
@@ -100,26 +102,27 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
val center = Vectors.dense(1.0, 3.0, 4.0)
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.size === 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
+ initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
}
@@ -145,25 +148,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
+ initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
data.unpersist()
@@ -183,15 +187,15 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// it will make at least five passes, and it will give non-zero probability to each
// unselected point as long as it hasn't yet selected all of them
- var model = KMeans.train(rdd, k=5, maxIterations=1)
+ var model = KMeans.train(rdd, k = 5, maxIterations = 1)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
// Iterations of Lloyd's should not change the answer either
- model = KMeans.train(rdd, k=5, maxIterations=10)
+ model = KMeans.train(rdd, k = 5, maxIterations = 10)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
// Neither should more runs
- model = KMeans.train(rdd, k=5, maxIterations=10, runs=5)
+ model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
}
@@ -220,3 +224,22 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
}
}
}
+
+class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble)))
+ }.cache()
+ for (initMode <- Seq(KMeans.RANDOM, KMeans.K_MEANS_PARALLEL)) {
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = KMeans.train(points, 2, 2, 1, initMode)
+ val predictions = model.predict(points).collect()
+ val cost = model.computeCost(points)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index a961f89456a18..325b817980f68 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
+import scala.util.Random
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
+import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class RowMatrixSuite extends FunSuite with LocalSparkContext {
@@ -193,3 +194,27 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
}
}
}
+
+class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ var mat: RowMatrix = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ val m = 4
+ val n = 200000
+ val rows = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble())))
+ }
+ mat = new RowMatrix(rows)
+ }
+
+ test("task size should be small in svd") {
+ val svd = mat.computeSVD(1, computeU = true)
+ }
+
+ test("task size should be small in summarize") {
+ val summary = mat.computeColumnSummaryStatistics()
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 951b4f7c6e6f4..dfb2eb7f0d14e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.optimization
-import scala.util.Random
import scala.collection.JavaConversions._
+import scala.util.Random
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import org.scalatest.{FunSuite, Matchers}
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object GradientDescentSuite {
@@ -46,7 +45,7 @@ object GradientDescentSuite {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
- val unifRand = new scala.util.Random(45)
+ val unifRand = new Random(45)
val rLogis = (0 until nPoints).map { i =>
val u = unifRand.nextDouble()
math.log(u) - math.log(1.0-u)
@@ -144,3 +143,26 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers
"should be initialWeightsWithIntercept.")
}
}
+
+class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val (weights, loss) = GradientDescent.runMiniBatchSGD(
+ points,
+ new LogisticGradient,
+ new SquaredL2Updater,
+ 0.1,
+ 2,
+ 1.0,
+ 1.0,
+ Vectors.dense(new Array[Double](n)))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index fe7a9033cd5f4..ff414742e8393 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.mllib.optimization
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import scala.util.Random
+
+import org.scalatest.{FunSuite, Matchers}
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
@@ -230,3 +231,24 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
"The weight differences between LBFGS and GD should be within 2%.")
}
}
+
+class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small") {
+ val m = 10
+ val n = 200000
+ val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble))))
+ }.cache()
+ val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater)
+ .setNumCorrections(1)
+ .setConvergenceTol(1e-12)
+ .setMaxNumIterations(1)
+ .setRegParam(1.0)
+ val random = new Random(0)
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val weights = lbfgs.optimize(examples, Vectors.dense(Array.fill(n)(random.nextDouble)))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
new file mode 100644
index 0000000000000..974dec4c0b5ee
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.util.StatCounter
+
+// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
+class DistributionGeneratorSuite extends FunSuite {
+
+ def apiChecks(gen: DistributionGenerator) {
+
+ // resetting seed should generate the same sequence of random numbers
+ gen.setSeed(42L)
+ val array1 = (0 until 1000).map(_ => gen.nextValue())
+ gen.setSeed(42L)
+ val array2 = (0 until 1000).map(_ => gen.nextValue())
+ assert(array1.equals(array2))
+
+ // copy() should return a new instance with its own rng,
+ // i.e. setting different seeds on different instances produces different sequences of
+ // random numbers.
+ val gen2 = gen.copy()
+ gen.setSeed(0L)
+ val array3 = (0 until 1000).map(_ => gen.nextValue())
+ gen2.setSeed(1L)
+ val array4 = (0 until 1000).map(_ => gen2.nextValue())
+ // Compare arrays instead of elements since individual elements can coincide by chance but the
+ // sequences should differ given two different seeds.
+ assert(!array3.equals(array4))
+
+ // test that setting the same seed in the copied instance produces the same sequence of numbers
+ gen.setSeed(0L)
+ val array5 = (0 until 1000).map(_ => gen.nextValue())
+ gen2.setSeed(0L)
+ val array6 = (0 until 1000).map(_ => gen2.nextValue())
+ assert(array5.equals(array6))
+ }
+
+ def distributionChecks(gen: DistributionGenerator,
+ mean: Double = 0.0,
+ stddev: Double = 1.0,
+ epsilon: Double = 0.01) {
+ for (seed <- 0 until 5) {
+ gen.setSeed(seed.toLong)
+ val sample = (0 until 100000).map { _ => gen.nextValue()}
+ val stats = new StatCounter(sample)
+ assert(math.abs(stats.mean - mean) < epsilon)
+ assert(math.abs(stats.stdev - stddev) < epsilon)
+ }
+ }
+
+ test("UniformGenerator") {
+ val uniform = new UniformGenerator()
+ apiChecks(uniform)
+ // Stddev of uniform distribution = (ub - lb) / math.sqrt(12)
+ distributionChecks(uniform, 0.5, 1 / math.sqrt(12))
+ }
+
+ test("StandardNormalGenerator") {
+ val normal = new StandardNormalGenerator()
+ apiChecks(normal)
+ distributionChecks(normal, 0.0, 1.0)
+ }
+
+ test("PoissonGenerator") {
+ // mean = 0.0 will not pass the API checks since 0.0 is always deterministically produced.
+ for (mean <- List(1.0, 5.0, 100.0)) {
+ val poisson = new PoissonGenerator(mean)
+ apiChecks(poisson)
+ distributionChecks(poisson, mean, math.sqrt(mean), 0.1)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
new file mode 100644
index 0000000000000..6aa4f803df0f7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD}
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.StatCounter
+
+/*
+ * Note: avoid including APIs that do not set the seed for the RNG in unit tests
+ * in order to guarantee deterministic behavior.
+ *
+ * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
+ */
+class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Serializable {
+
+ def testGeneratedRDD(rdd: RDD[Double],
+ expectedSize: Long,
+ expectedNumPartitions: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double = 0.01) {
+ val stats = rdd.stats()
+ assert(expectedSize === stats.count)
+ assert(expectedNumPartitions === rdd.partitions.size)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ // assume test RDDs are small
+ def testGeneratedVectorRDD(rdd: RDD[Vector],
+ expectedRows: Long,
+ expectedColumns: Int,
+ expectedNumPartitions: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double = 0.01) {
+ assert(expectedNumPartitions === rdd.partitions.size)
+ val values = new ArrayBuffer[Double]()
+ rdd.collect.foreach { vector =>
+ assert(vector.size === expectedColumns)
+ values ++= vector.toArray
+ }
+ assert(expectedRows === values.size / expectedColumns)
+ val stats = new StatCounter(values)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ test("RandomRDD sizes") {
+
+ // some cases where size % numParts != 0 to test that getPartitions behaves correctly
+ for ((size, numPartitions) <- List((10000, 6), (12345, 1), (1000, 101))) {
+ val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L)
+ assert(rdd.count() === size)
+ assert(rdd.partitions.size === numPartitions)
+
+ // check that partition sizes are balanced
+ val partSizes = rdd.partitions.map(p => p.asInstanceOf[RandomRDDPartition].size.toDouble)
+ val partStats = new StatCounter(partSizes)
+ assert(partStats.max - partStats.min <= 1)
+ }
+
+ // size > Int.MaxValue
+ val size = Int.MaxValue.toLong * 100L
+ val numPartitions = 101
+ val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L)
+ assert(rdd.partitions.size === numPartitions)
+ val count = rdd.partitions.foldLeft(0L) { (count, part) =>
+ count + part.asInstanceOf[RandomRDDPartition].size
+ }
+ assert(count === size)
+
+ // size needs to be positive
+ intercept[IllegalArgumentException] { new RandomRDD(sc, 0, 10, new UniformGenerator, 0L) }
+
+ // numPartitions needs to be positive
+ intercept[IllegalArgumentException] { new RandomRDD(sc, 100, 0, new UniformGenerator, 0L) }
+
+ // partition size needs to be <= Int.MaxValue
+ intercept[IllegalArgumentException] {
+ new RandomRDD(sc, Int.MaxValue.toLong * 100L, 99, new UniformGenerator, 0L)
+ }
+ }
+
+ test("randomRDD for different distributions") {
+ val size = 100000L
+ val numPartitions = 10
+ val poissonMean = 100.0
+
+ for (seed <- 0 until 5) {
+ val uniform = RandomRDDGenerators.uniformRDD(sc, size, numPartitions, seed)
+ testGeneratedRDD(uniform, size, numPartitions, 0.5, 1 / math.sqrt(12))
+
+ val normal = RandomRDDGenerators.normalRDD(sc, size, numPartitions, seed)
+ testGeneratedRDD(normal, size, numPartitions, 0.0, 1.0)
+
+ val poisson = RandomRDDGenerators.poissonRDD(sc, poissonMean, size, numPartitions, seed)
+ testGeneratedRDD(poisson, size, numPartitions, poissonMean, math.sqrt(poissonMean), 0.1)
+ }
+
+ // mock distribution to check that partitions have unique seeds
+ val random = RandomRDDGenerators.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L)
+ assert(random.collect.size === random.collect.distinct.size)
+ }
+
+ test("randomVectorRDD for different distributions") {
+ val rows = 1000L
+ val cols = 100
+ val parts = 10
+ val poissonMean = 100.0
+
+ for (seed <- 0 until 5) {
+ val uniform = RandomRDDGenerators.uniformVectorRDD(sc, rows, cols, parts, seed)
+ testGeneratedVectorRDD(uniform, rows, cols, parts, 0.5, 1 / math.sqrt(12))
+
+ val normal = RandomRDDGenerators.normalVectorRDD(sc, rows, cols, parts, seed)
+ testGeneratedVectorRDD(normal, rows, cols, parts, 0.0, 1.0)
+
+ val poisson = RandomRDDGenerators.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed)
+ testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1)
+ }
+ }
+}
+
+private[random] class MockDistro extends DistributionGenerator {
+
+ var seed = 0L
+
+ // This allows us to check that each partition has a different seed
+ override def nextValue(): Double = seed.toDouble
+
+ override def setSeed(seed: Long) = this.seed = seed
+
+ override def copy(): MockDistro = new MockDistro
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index bfa42959c8ead..7aa96421aed87 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.regression
+import scala.util.Random
+
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class LassoSuite extends FunSuite with LocalSparkContext {
@@ -113,3 +116,19 @@ class LassoSuite extends FunSuite with LocalSparkContext {
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LassoWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 7aaad7d7a3e39..4f89112b650c5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.regression
+import scala.util.Random
+
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
@@ -122,3 +125,19 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
}
}
+
+class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LinearRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 67768e17fbe6d..727bbd051ff15 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,11 +17,14 @@
package org.apache.spark.mllib.regression
-import org.scalatest.FunSuite
+import scala.util.Random
import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
@@ -73,3 +76,19 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
}
+
+class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = RidgeRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
new file mode 100644
index 0000000000000..5e9101cdd3804
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.{Suite, BeforeAndAfterAll}
+
+import org.apache.spark.{SparkConf, SparkContext}
+
+trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
+ @transient var sc: SparkContext = _
+
+ override def beforeAll() {
+ val conf = new SparkConf()
+ .setMaster("local-cluster[2, 1, 512]")
+ .setAppName("test-cluster")
+ .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
+ sc = new SparkContext(conf)
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ if (sc != null) {
+ sc.stop()
+ }
+ super.afterAll()
+ }
+}
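
A minimal PySpark sketch, not part of this patch and with illustrative names only, of the failure mode the local-cluster suites above guard against: the trait caps spark.akka.frameSize at 1 MB, so a task closure that captures a large local array makes every serialized task carry that array, while building the data inside the partitions keeps tasks small.

from pyspark import SparkConf, SparkContext

conf = SparkConf().setMaster("local[2]").setAppName("closure-size-sketch")
sc = SparkContext(conf=conf)

n = 200000
weights = [0.0] * n
# Anti-pattern: `weights` is captured by the lambda, so every serialized task
# ships the whole array and can exceed a small RPC frame size.
bad = sc.parallelize(range(4), 2).map(lambda i: sum(weights))
# Preferred: build (or load) the large data inside the partitions, mirroring
# what the new Scala regression tests do with mapPartitionsWithIndex.
good = sc.parallelize(range(4), 2).mapPartitions(lambda it: (sum([0.0] * n) for _ in it))

bad.count()
good.count()
sc.stop()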
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
index 0d4868f3d9e42..7857d9e5ee5c4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -20,13 +20,16 @@ package org.apache.spark.mllib.util
import org.scalatest.Suite
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
@transient var sc: SparkContext = _
override def beforeAll() {
- sc = new SparkContext("local", "test")
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ sc = new SparkContext(conf)
super.beforeAll()
}
diff --git a/pom.xml b/pom.xml
index 4e2d64a833640..3e9d388180d8e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -95,6 +95,7 @@
sql/catalyst
sql/core
sql/hive
+ sql/hive-thriftserver
repl
assembly
external/twitter
@@ -252,9 +253,9 @@
3.3.2
- <groupId>commons-codec</groupId>
- <artifactId>commons-codec</artifactId>
- <version>1.5</version>
+ <groupId>commons-codec</groupId>
+ <artifactId>commons-codec</artifactId>
+ <version>1.5</version>

com.google.code.findbugs
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index e9220db6b1f9a..5ff88f0dd1cac 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -31,7 +31,6 @@ import com.typesafe.tools.mima.core._
* MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap")
*/
object MimaExcludes {
-
def excludes(version: String) =
version match {
case v if v.startsWith("1.1") =>
@@ -62,6 +61,15 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.storage.MemoryStore.Entry")
) ++
+ Seq(
+ // Renamed putValues -> putArray + putIterator
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.storage.MemoryStore.putValues"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.storage.DiskStore.putValues"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.storage.TachyonStore.putValues")
+ ) ++
Seq(
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this")
) ++
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 52fd61d2234b7..4b3e05b36f6bf 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConversions._
import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
+import sbtunidoc.Plugin.genjavadocSettings
import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings}
import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
import net.virtualvoid.sbt.graph.Plugin.graphSettings
@@ -29,11 +30,11 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
- val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming,
- streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) =
- Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql",
- "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
- "streaming-zeromq").map(ProjectRef(buildLocation, _))
+ val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, sql,
+ streaming, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) =
+ Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
+ "spark", "sql", "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt",
+ "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _))
val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) =
Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl")
@@ -99,7 +100,7 @@ object SparkBuild extends PomBuild {
Properties.envOrNone("SBT_MAVEN_PROPERTIES") match {
case Some(v) =>
v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1)))
- case _ =>
+ case _ =>
}
override val userPropertiesMap = System.getProperties.toMap
@@ -107,7 +108,7 @@ object SparkBuild extends PomBuild {
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
- lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ Seq (
+ lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ genjavadocSettings ++ Seq (
javaHome := Properties.envOrNone("JAVA_HOME").map(file),
incOptions := incOptions.value.withNameHashing(true),
retrieveManaged := true,
@@ -157,7 +158,7 @@ object SparkBuild extends PomBuild {
/* Enable Mima for all projects except spark, hive, catalyst, sql and repl */
// TODO: Add Sql to mima checks
- allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)).
+ allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl).contains(x)).
foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x))
/* Enable Assembly for all assembly projects */
diff --git a/python/epydoc.conf b/python/epydoc.conf
index b73860bad8263..51c0faf359939 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -35,4 +35,4 @@ private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
pyspark.rddsampler pyspark.daemon pyspark.mllib._common
- pyspark.mllib.tests
+ pyspark.mllib.tests pyspark.shuffle
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 024fb881877c9..e8ac9895cf54a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -37,6 +37,15 @@
from py4j.java_collections import ListConverter
+# These are special default configs for PySpark; they override Spark's
+# built-in defaults unless the user has configured the keys explicitly.
+DEFAULT_CONFIGS = {
+ "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
+ "spark.serializer.objectStreamReset": 100,
+ "spark.rdd.compress": True,
+}
+
+
class SparkContext(object):
"""
Main entry point for Spark functionality. A SparkContext represents the
@@ -101,7 +110,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
else:
self.serializer = BatchedSerializer(self._unbatched_serializer,
batchSize)
- self._conf.setIfMissing("spark.rdd.compress", "true")
+
# Set any parameters passed directly to us on the conf
if master:
self._conf.setMaster(master)
@@ -112,6 +121,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
if environment:
for key, value in environment.iteritems():
self._conf.setExecutorEnv(key, value)
+ for key, value in DEFAULT_CONFIGS.items():
+ self._conf.setIfMissing(key, value)
# Check that we have at least the required parameters
if not self._conf.contains("spark.master"):
@@ -216,6 +227,13 @@ def setSystemProperty(cls, key, value):
SparkContext._ensure_initialized()
SparkContext._jvm.java.lang.System.setProperty(key, value)
+ @property
+ def version(self):
+ """
+ The version of Spark on which this application is running.
+ """
+ return self._jsc.version()
+
@property
def defaultParallelism(self):
"""
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 43b491a9716fc..8e3ad6b783b6c 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -72,9 +72,9 @@
# Python interpreter must agree on what endian the machine is.
-DENSE_VECTOR_MAGIC = 1
+DENSE_VECTOR_MAGIC = 1
SPARSE_VECTOR_MAGIC = 2
-DENSE_MATRIX_MAGIC = 3
+DENSE_MATRIX_MAGIC = 3
LABELED_POINT_MAGIC = 4
@@ -97,8 +97,28 @@ def _deserialize_numpy_array(shape, ba, offset, dtype=float64):
return ar.copy()
+def _serialize_double(d):
+ """
+ Serialize a double (float or numpy.float64) into a mutually understood format.
+ """
+ if type(d) == float or type(d) == float64:
+ d = float64(d)
+ ba = bytearray(8)
+ _copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64)
+ return ba
+ else:
+ raise TypeError("_serialize_double called on non-float input")
+
+
def _serialize_double_vector(v):
- """Serialize a double vector into a mutually understood format.
+ """
+ Serialize a double vector into a mutually understood format.
+
+ Note: we currently do not use a magic byte for double for storage
+ efficiency. This should be reconsidered when we add Ser/De for other
+ 8-byte types (e.g. Long), for safety. The corresponding deserializer,
+ _deserialize_double, needs to be modified as well if the serialization
+ scheme changes.
>>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
@@ -148,6 +168,28 @@ def _serialize_sparse_vector(v):
return ba
+def _deserialize_double(ba, offset=0):
+ """Deserialize a double from a mutually understood format.
+
+ >>> import sys
+ >>> _deserialize_double(_serialize_double(123.0)) == 123.0
+ True
+ >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0
+ True
+ >>> x = sys.float_info.max
+ >>> _deserialize_double(_serialize_double(sys.float_info.max)) == x
+ True
+ >>> y = float64(sys.float_info.max)
+ >>> _deserialize_double(_serialize_double(sys.float_info.max)) == y
+ True
+ """
+ if type(ba) != bytearray:
+ raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba))
+ if len(ba) - offset != 8:
+ raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb)
+ return struct.unpack("d", ba[offset:])[0]
+
+
def _deserialize_double_vector(ba, offset=0):
"""Deserialize a double vector from a mutually understood format.
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index a38dd0b9237c5..b84d976114f0d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,9 +39,11 @@
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler
+from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
+ get_used_memory
from py4j.java_collections import ListConverter, MapConverter
@@ -197,6 +199,22 @@ def _replaceRoot(self, value):
self._sink(1)
+def _parse_memory(s):
+ """
+ Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
+ return the value in MB
+
+ >>> _parse_memory("256m")
+ 256
+ >>> _parse_memory("2g")
+ 2048
+ """
+ units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
+ if s[-1] not in units:
+ raise ValueError("invalid format: " + s)
+ return int(float(s[:-1]) * units[s[-1].lower()])
+
+
class RDD(object):
"""
@@ -393,7 +411,7 @@ def sample(self, withReplacement, fraction, seed=None):
>>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
[2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
"""
- assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
# this is ported from scala/spark/RDD.scala
@@ -1207,20 +1225,49 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
- # Transferring O(n) objects to Java is too expensive. Instead, we'll
- # form the hash buckets in Python, transferring O(numPartitions) objects
- # to Java. Each object is a (splitNumber, [objects]) pair.
+ # Transferring O(n) objects to Java is too expensive.
+ # Instead, we'll form the hash buckets in Python,
+ # transferring O(numPartitions) objects to Java.
+ # Each object is a (splitNumber, [objects]) pair.
+ # To avoid creating overly large objects, the items are
+ # grouped into chunks.
outputSerializer = self.ctx._unbatched_serializer
+ limit = (_parse_memory(self.ctx._conf.get(
+ "spark.python.worker.memory", "512m")) / 2)
+
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
+ c, batch = 0, min(10 * numPartitions, 1000)
for (k, v) in iterator:
buckets[partitionFunc(k) % numPartitions].append((k, v))
+ c += 1
+
+ # check the used memory and the average chunk size
+ if (c % 1000 == 0 and get_used_memory() > limit
+ or c > batch):
+ n, size = len(buckets), 0
+ for split in buckets.keys():
+ yield pack_long(split)
+ d = outputSerializer.dumps(buckets[split])
+ del buckets[split]
+ yield d
+ size += len(d)
+
+ avg = (size / n) >> 20
+ # let 1M < avg < 10M
+ if avg < 1:
+ batch *= 1.5
+ elif avg > 10:
+ batch = max(batch / 1.5, 1)
+ c = 0
+
for (split, items) in buckets.iteritems():
yield pack_long(split)
yield outputSerializer.dumps(items)
+
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
with _JavaStackTrace(self.context) as st:
@@ -1230,8 +1277,8 @@ def add_shuffle_key(split, iterator):
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
- # This is required so that id(partitionFunc) remains unique, even if
- # partitionFunc is a lambda:
+ # This is required so that id(partitionFunc) remains unique,
+ # even if partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
return rdd
@@ -1265,26 +1312,28 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ serializer = self.ctx.serializer
+ spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
+ == 'true')
+ memory = _parse_memory(self.ctx._conf.get(
+ "spark.python.worker.memory", "512m"))
+ agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
def combineLocally(iterator):
- combiners = {}
- for x in iterator:
- (k, v) = x
- if k not in combiners:
- combiners[k] = createCombiner(v)
- else:
- combiners[k] = mergeValue(combiners[k], v)
- return combiners.iteritems()
+ merger = ExternalMerger(agg, memory * 0.9, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeValues(iterator)
+ return merger.iteritems()
+
locally_combined = self.mapPartitions(combineLocally)
shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
- combiners = {}
- for (k, v) in iterator:
- if k not in combiners:
- combiners[k] = v
- else:
- combiners[k] = mergeCombiners(combiners[k], v)
- return combiners.iteritems()
+ merger = ExternalMerger(agg, memory, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeCombiners(iterator)
+ return merger.iteritems()
+
return shuffled.mapPartitions(_mergeCombiners)
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
@@ -1343,7 +1392,8 @@ def mergeValue(xs, x):
return xs
def mergeCombiners(a, b):
- return a + b
+ a.extend(b)
+ return a
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
numPartitions).mapValues(lambda x: ResultIterable(x))
@@ -1406,6 +1456,27 @@ def cogroup(self, other, numPartitions=None):
"""
return python_cogroup((self, other), numPartitions)
+ def sampleByKey(self, withReplacement, fractions, seed=None):
+ """
+ Return a subset of this RDD sampled by key (via stratified sampling).
+ Create a sample of this RDD using variable sampling rates for
+ different keys as specified by fractions, a key to sampling rate map.
+
+ >>> fractions = {"a": 0.2, "b": 0.1}
+ >>> rdd = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(0, 1000)))
+ >>> sample = dict(rdd.sampleByKey(False, fractions, 2).groupByKey().collect())
+ >>> 100 < len(sample["a"]) < 300 and 50 < len(sample["b"]) < 150
+ True
+ >>> max(sample["a"]) <= 999 and min(sample["a"]) >= 0
+ True
+ >>> max(sample["b"]) <= 999 and min(sample["b"]) >= 0
+ True
+ """
+ for fraction in fractions.values():
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+ return self.mapPartitionsWithIndex( \
+ RDDStratifiedSampler(withReplacement, fractions, seed).func, True)
+
def subtractByKey(self, other, numPartitions=None):
"""
Return each (key, value) pair in C{self} that has no pair with matching
@@ -1616,7 +1687,6 @@ def _jrdd(self):
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
- class_tag = self._prev_jrdd.classTag()
env = MapConverter().convert(self.ctx.environment,
self.ctx._gateway._gateway_client)
includes = ListConverter().convert(self.ctx._python_includes,
@@ -1625,8 +1695,7 @@ def _jrdd(self):
bytearray(pickled_command),
env, includes, self.preservesPartitioning,
self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator,
- class_tag)
+ broadcast_vars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
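
A standalone sketch of the chunking strategy used by add_shuffle_key above, with hypothetical helper names (dumps, used_memory_mb): buckets are flushed when memory is high or enough items have accumulated, and the batch size adapts so the average serialized chunk stays roughly between 1 MB and 10 MB.

from collections import defaultdict

def chunked_buckets(items, num_partitions, dumps, used_memory_mb, limit_mb):
    # `dumps` serializes a list of pairs; `used_memory_mb` reports process memory.
    buckets = defaultdict(list)
    c, batch = 0, min(10 * num_partitions, 1000)
    for k, v in items:
        buckets[hash(k) % num_partitions].append((k, v))
        c += 1
        if (c % 1000 == 0 and used_memory_mb() > limit_mb) or c > batch:
            n, size = len(buckets), 0
            for split in list(buckets.keys()):
                data = dumps(buckets.pop(split))
                size += len(data)
                yield split, data
            avg = (size // n) >> 20           # average chunk size in MB
            if avg < 1:
                batch *= 1.5                  # chunks too small: batch more items
            elif avg > 10:
                batch = max(batch / 1.5, 1)   # chunks too big: flush sooner
            c = 0
    for split, pairs in buckets.items():      # flush whatever is left
        yield split, dumps(pairs)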
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 7ff1c316c7623..2df000fdb08ca 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,8 +19,8 @@
import random
-class RDDSampler(object):
- def __init__(self, withReplacement, fraction, seed=None):
+class RDDSamplerBase(object):
+ def __init__(self, withReplacement, seed=None):
try:
import numpy
self._use_numpy = True
@@ -32,7 +32,6 @@ def __init__(self, withReplacement, fraction, seed=None):
self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._withReplacement = withReplacement
- self._fraction = fraction
self._random = None
self._split = None
self._rand_initialized = False
@@ -94,6 +93,12 @@ def shuffle(self, vals):
else:
self._random.shuffle(vals, self._random.random)
+
+class RDDSampler(RDDSamplerBase):
+ def __init__(self, withReplacement, fraction, seed=None):
+ RDDSamplerBase.__init__(self, withReplacement, seed)
+ self._fraction = fraction
+
def func(self, split, iterator):
if self._withReplacement:
for obj in iterator:
@@ -107,3 +112,22 @@ def func(self, split, iterator):
for obj in iterator:
if self.getUniformSample(split) <= self._fraction:
yield obj
+
+class RDDStratifiedSampler(RDDSamplerBase):
+ def __init__(self, withReplacement, fractions, seed=None):
+ RDDSamplerBase.__init__(self, withReplacement, seed)
+ self._fractions = fractions
+
+ def func(self, split, iterator):
+ if self._withReplacement:
+ for key, val in iterator:
+ # For large datasets, the expected number of occurrences of each element in
+ # a sample with replacement is Poisson(frac). We use that to get a count for
+ # each element.
+ count = self.getPoissonSample(split, mean=self._fractions[key])
+ for _ in range(0, count):
+ yield key, val
+ else:
+ for key, val in iterator:
+ if self.getUniformSample(split) <= self._fractions[key]:
+ yield key, val
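
A plain-Python sketch, not part of the patch, of the per-key logic RDDStratifiedSampler implements: without replacement each pair survives a Bernoulli trial with probability fractions[key]; with replacement the number of copies is drawn from Poisson(fractions[key]).

import random

def sample_by_key(pairs, fractions, with_replacement, seed=42):
    rnd = random.Random(seed)
    for key, val in pairs:
        frac = fractions[key]
        if with_replacement:
            # Draw the number of copies from Poisson(frac); the real sampler
            # uses numpy.random.poisson when numpy is available.
            for _ in range(poisson(rnd, frac)):
                yield key, val
        elif rnd.random() <= frac:
            yield key, val

def poisson(rnd, mean):
    # Knuth's algorithm; fine for the small means used as sampling rates.
    import math
    l, k, p = math.exp(-mean), 0, 1.0
    while True:
        p *= rnd.random()
        if p <= l:
            return k
        k += 1

pairs = [("a", i) for i in range(1000)] + [("b", i) for i in range(1000)]
sample = list(sample_by_key(pairs, {"a": 0.2, "b": 0.1}, with_replacement=False))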
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9be78b39fbc21..03b31ae9624c2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -193,7 +193,7 @@ def load_stream(self, stream):
return chain.from_iterable(self._load_stream_without_unbatching(stream))
def _load_stream_without_unbatching(self, stream):
- return self.serializer.load_stream(stream)
+ return self.serializer.load_stream(stream)
def __eq__(self, other):
return (isinstance(other, BatchedSerializer) and
@@ -302,6 +302,33 @@ class MarshalSerializer(FramedSerializer):
loads = marshal.loads
+class AutoSerializer(FramedSerializer):
+ """
+ Choose marshal or cPickle as the serialization protocol automatically
+ """
+ def __init__(self):
+ FramedSerializer.__init__(self)
+ self._type = None
+
+ def dumps(self, obj):
+ if self._type is not None:
+ return 'P' + cPickle.dumps(obj, -1)
+ try:
+ return 'M' + marshal.dumps(obj)
+ except Exception:
+ self._type = 'P'
+ return 'P' + cPickle.dumps(obj, -1)
+
+ def loads(self, obj):
+ _type = obj[0]
+ if _type == 'M':
+ return marshal.loads(obj[1:])
+ elif _type == 'P':
+ return cPickle.loads(obj[1:])
+ else:
+ raise ValueError("invalid sevialization type: %s" % _type)
+
+
class UTF8Deserializer(Serializer):
"""
Deserializes streams written by String.getBytes.
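
A brief usage sketch for the AutoSerializer added above, assuming this patch is applied: the first object marshal cannot handle permanently switches the serializer to cPickle, and the one-byte prefix lets loads pick the right decoder.

from pyspark.serializers import AutoSerializer

class Point(object):               # marshal cannot serialize instances
    def __init__(self, x):
        self.x = x

ser = AutoSerializer()
assert ser.loads(ser.dumps([1, 2, 3])) == [1, 2, 3]    # 'M' prefix: marshal
p = ser.loads(ser.dumps(Point(7)))                     # 'P' prefix: cPickle fallback
assert p.x == 7
assert ser.dumps([1, 2, 3])[0] == 'P'                  # the fallback is sticky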
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
new file mode 100644
index 0000000000000..e3923d1c36c57
--- /dev/null
+++ b/python/pyspark/shuffle.py
@@ -0,0 +1,439 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import platform
+import shutil
+import warnings
+import gc
+
+from pyspark.serializers import BatchedSerializer, PickleSerializer
+
+try:
+ import psutil
+
+ def get_used_memory():
+ """ Return the used memory in MB """
+ process = psutil.Process(os.getpid())
+ if hasattr(process, "memory_info"):
+ info = process.memory_info()
+ else:
+ info = process.get_memory_info()
+ return info.rss >> 20
+except ImportError:
+
+ def get_used_memory():
+ """ Return the used memory in MB """
+ if platform.system() == 'Linux':
+ for line in open('/proc/self/status'):
+ if line.startswith('VmRSS:'):
+ return int(line.split()[1]) >> 10
+ else:
+ warnings.warn("Please install psutil to have better "
+ "support with spilling")
+ if platform.system() == "Darwin":
+ import resource
+ rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+ return rss >> 20
+ # TODO: support windows
+ return 0
+
+
+class Aggregator(object):
+
+ """
+ Aggregator holds the three functions used to merge values into combiners:
+
+ createCombiner: (value) -> combiner
+ mergeValue: (combiner, value) -> combiner
+ mergeCombiners: (combiner, combiner) -> combiner
+ """
+
+ def __init__(self, createCombiner, mergeValue, mergeCombiners):
+ self.createCombiner = createCombiner
+ self.mergeValue = mergeValue
+ self.mergeCombiners = mergeCombiners
+
+
+class SimpleAggregator(Aggregator):
+
+ """
+ SimpleAggregator is useful for cases where the combiner has the
+ same type as the values
+ """
+
+ def __init__(self, combiner):
+ Aggregator.__init__(self, lambda x: x, combiner, combiner)
+
+
+class Merger(object):
+
+ """
+ Merge shuffled data together by aggregator
+ """
+
+ def __init__(self, aggregator):
+ self.agg = aggregator
+
+ def mergeValues(self, iterator):
+ """ Combine the items by creator and combiner """
+ raise NotImplementedError
+
+ def mergeCombiners(self, iterator):
+ """ Merge the combined items by mergeCombiner """
+ raise NotImplementedError
+
+ def iteritems(self):
+ """ Return the merged items ad iterator """
+ raise NotImplementedError
+
+
+class InMemoryMerger(Merger):
+
+ """
+ In memory merger based on in-memory dict.
+ """
+
+ def __init__(self, aggregator):
+ Merger.__init__(self, aggregator)
+ self.data = {}
+
+ def mergeValues(self, iterator):
+ """ Combine the items by creator and combiner """
+ # speed up attributes lookup
+ d, creator = self.data, self.agg.createCombiner
+ comb = self.agg.mergeValue
+ for k, v in iterator:
+ d[k] = comb(d[k], v) if k in d else creator(v)
+
+ def mergeCombiners(self, iterator):
+ """ Merge the combined items by mergeCombiner """
+ # speed up attributes lookup
+ d, comb = self.data, self.agg.mergeCombiners
+ for k, v in iterator:
+ d[k] = comb(d[k], v) if k in d else v
+
+ def iteritems(self):
+ """ Return the merged items ad iterator """
+ return self.data.iteritems()
+
+
+class ExternalMerger(Merger):
+
+ """
+ External merger dumps the aggregated data to disk when memory
+ usage goes above the limit, then merges the spilled files together.
+
+ This class works as follows:
+
+ - It repeatedly combines incoming items into a single dict held
+ in memory.
+
+ - When the used memory goes above the memory limit, it splits
+ the combined data into partitions by hash code and dumps them
+ to disk, one file per partition.
+
+ - It then goes through the rest of the iterator, combining items
+ into per-partition dicts by hash. Whenever the used memory goes
+ over the limit again, it dumps all the dicts to disk, one file
+ per dict, and repeats until all items are combined.
+
+ - Before returning any items, it loads and combines each
+ partition separately, yielding its items before loading the
+ next partition.
+
+ - While loading a partition, if memory goes over the limit, the
+ loaded data is partitioned again, dumped to disk, and loaded
+ back partition by partition.
+
+ `data` and `pdata` hold the merged items in memory. At first,
+ everything is merged into `data`. Once the used memory goes over
+ the limit, the items in `data` are dumped to disk and `data` is
+ cleared; the remaining items are merged into `pdata` and then
+ dumped to disk as well. Before returning, any items left in
+ `pdata` are dumped to disk.
+
+ Finally, if any items were spilled to disk, each partition is
+ merged into `data`, yielded, and then cleared.
+
+ >>> agg = SimpleAggregator(lambda x, y: x + y)
+ >>> merger = ExternalMerger(agg, 10)
+ >>> N = 10000
+ >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
+ >>> assert merger.spills > 0
+ >>> sum(v for k,v in merger.iteritems())
+ 499950000
+
+ >>> merger = ExternalMerger(agg, 10)
+ >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
+ >>> assert merger.spills > 0
+ >>> sum(v for k,v in merger.iteritems())
+ 499950000
+ """
+
+ # the max total partitions created recursively
+ MAX_TOTAL_PARTITIONS = 4096
+
+ def __init__(self, aggregator, memory_limit=512, serializer=None,
+ localdirs=None, scale=1, partitions=59, batch=1000):
+ Merger.__init__(self, aggregator)
+ self.memory_limit = memory_limit
+ # default serializer is only used for tests
+ self.serializer = serializer or \
+ BatchedSerializer(PickleSerializer(), 1024)
+ self.localdirs = localdirs or self._get_dirs()
+ # number of partitions used when spilling data to disk
+ self.partitions = partitions
+ # check memory usage after this many items have been merged
+ self.batch = batch
+ # scale is used to scale down the hash of key for recursive hash map
+ self.scale = scale
+ # unpartitioned merged data
+ self.data = {}
+ # partitioned merged data, list of dicts
+ self.pdata = []
+ # number of times the data has been spilled to disk
+ self.spills = 0
+ # randomize the hash of key, id(o) is the address of o (aligned by 8)
+ self._seed = id(self) + 7
+
+ def _get_dirs(self):
+ """ Get all the directories """
+ path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
+ dirs = path.split(",")
+ return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
+ for d in dirs]
+
+ def _get_spill_dir(self, n):
+ """ Choose one directory for spill by number n """
+ return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
+
+ def _next_limit(self):
+ """
+ Return the next memory limit. If memory is not released after
+ spilling, the next dump happens only once the used memory starts
+ to increase again.
+ """
+ return max(self.memory_limit, get_used_memory() * 1.05)
+
+ def mergeValues(self, iterator):
+ """ Combine the items by creator and combiner """
+ iterator = iter(iterator)
+ # speedup attribute lookup
+ creator, comb = self.agg.createCombiner, self.agg.mergeValue
+ d, c, batch = self.data, 0, self.batch
+
+ for k, v in iterator:
+ d[k] = comb(d[k], v) if k in d else creator(v)
+
+ c += 1
+ if c % batch == 0 and get_used_memory() > self.memory_limit:
+ self._spill()
+ self._partitioned_mergeValues(iterator, self._next_limit())
+ break
+
+ def _partition(self, key):
+ """ Return the partition for key """
+ return hash((key, self._seed)) % self.partitions
+
+ def _partitioned_mergeValues(self, iterator, limit=0):
+ """ Partition the items by key, then combine them """
+ # speedup attribute lookup
+ creator, comb = self.agg.createCombiner, self.agg.mergeValue
+ c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
+
+ for k, v in iterator:
+ d = pdata[hfun(k)]
+ d[k] = comb(d[k], v) if k in d else creator(v)
+ if not limit:
+ continue
+
+ c += 1
+ if c % batch == 0 and get_used_memory() > limit:
+ self._spill()
+ limit = self._next_limit()
+
+ def mergeCombiners(self, iterator, check=True):
+ """ Merge (K,V) pair by mergeCombiner """
+ iterator = iter(iterator)
+ # speedup attribute lookup
+ d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
+ c = 0
+ for k, v in iterator:
+ d[k] = comb(d[k], v) if k in d else v
+ if not check:
+ continue
+
+ c += 1
+ if c % batch == 0 and get_used_memory() > self.memory_limit:
+ self._spill()
+ self._partitioned_mergeCombiners(iterator, self._next_limit())
+ break
+
+ def _partitioned_mergeCombiners(self, iterator, limit=0):
+ """ Partition the items by key, then merge them """
+ comb, pdata = self.agg.mergeCombiners, self.pdata
+ c, hfun = 0, self._partition
+ for k, v in iterator:
+ d = pdata[hfun(k)]
+ d[k] = comb(d[k], v) if k in d else v
+ if not limit:
+ continue
+
+ c += 1
+ if c % self.batch == 0 and get_used_memory() > limit:
+ self._spill()
+ limit = self._next_limit()
+
+ def _spill(self):
+ """
+ Dump the already-partitioned data to disk.
+
+ It will dump the data in batch for better performance.
+ """
+ path = self._get_spill_dir(self.spills)
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ if not self.pdata:
+ # The data has not been partitioned yet; iterate over the
+ # dataset once and write the items into different files,
+ # using no additional memory. This branch is only taken the
+ # first time memory goes above the limit.
+
+ # open all the files for writing
+ streams = [open(os.path.join(path, str(i)), 'w')
+ for i in range(self.partitions)]
+
+ for k, v in self.data.iteritems():
+ h = self._partition(k)
+ # write one item per batch to stay compatible with load_stream;
+ # dumping them in larger batches would increase memory usage
+ self.serializer.dump_stream([(k, v)], streams[h])
+
+ for s in streams:
+ s.close()
+
+ self.data.clear()
+ self.pdata = [{} for i in range(self.partitions)]
+
+ else:
+ for i in range(self.partitions):
+ p = os.path.join(path, str(i))
+ with open(p, "w") as f:
+ # dump items in batch
+ self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+ self.pdata[i].clear()
+
+ self.spills += 1
+ gc.collect() # release the memory as much as possible
+
+ def iteritems(self):
+ """ Return all merged items as iterator """
+ if not self.pdata and not self.spills:
+ return self.data.iteritems()
+ return self._external_items()
+
+ def _external_items(self):
+ """ Return all partitioned items as iterator """
+ assert not self.data
+ if any(self.pdata):
+ self._spill()
+ hard_limit = self._next_limit()
+
+ try:
+ for i in range(self.partitions):
+ self.data = {}
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(i))
+ # do not check memory during merging
+ self.mergeCombiners(self.serializer.load_stream(open(p)),
+ False)
+
+ # limit the total partitions
+ if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
+ and j < self.spills - 1
+ and get_used_memory() > hard_limit):
+ self.data.clear() # will read from disk again
+ gc.collect() # release the memory as much as possible
+ for v in self._recursive_merged_items(i):
+ yield v
+ return
+
+ for v in self.data.iteritems():
+ yield v
+ self.data.clear()
+ gc.collect()
+
+ # remove the merged partition
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ os.remove(os.path.join(path, str(i)))
+
+ finally:
+ self._cleanup()
+
+ def _cleanup(self):
+ """ Clean up all the files in disks """
+ for d in self.localdirs:
+ shutil.rmtree(d, True)
+
+ def _recursive_merged_items(self, start):
+ """
+ Merge the partitioned items and return them as an iterator.
+
+ If one partition cannot fit in memory, it will be partitioned
+ and merged recursively.
+ """
+ # make sure all the data has been dumped to disk
+ assert not self.data
+ if any(self.pdata):
+ self._spill()
+ assert self.spills > 0
+
+ for i in range(start, self.partitions):
+ subdirs = [os.path.join(d, "parts", str(i))
+ for d in self.localdirs]
+ m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
+ subdirs, self.scale * self.partitions)
+ m.pdata = [{} for _ in range(self.partitions)]
+ limit = self._next_limit()
+
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(i))
+ m._partitioned_mergeCombiners(
+ self.serializer.load_stream(open(p)))
+
+ if get_used_memory() > limit:
+ m._spill()
+ limit = self._next_limit()
+
+ for v in m._external_items():
+ yield v
+
+ # remove the merged partition
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ os.remove(os.path.join(path, str(i)))
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 9c5ecd0bb02ab..8ba51461d106d 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -34,6 +34,7 @@
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
_have_scipy = False
try:
@@ -47,6 +48,62 @@
SPARK_HOME = os.environ["SPARK_HOME"]
+class TestMerger(unittest.TestCase):
+
+ def setUp(self):
+ self.N = 1 << 16
+ self.l = [i for i in xrange(self.N)]
+ self.data = zip(self.l, self.l)
+ self.agg = Aggregator(lambda x: [x],
+ lambda x, y: x.append(y) or x,
+ lambda x, y: x.extend(y) or x)
+
+ def test_in_memory(self):
+ m = InMemoryMerger(self.agg)
+ m.mergeValues(self.data)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ m = InMemoryMerger(self.agg)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ def test_small_dataset(self):
+ m = ExternalMerger(self.agg, 1000)
+ m.mergeValues(self.data)
+ self.assertEqual(m.spills, 0)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ m = ExternalMerger(self.agg, 1000)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
+ self.assertEqual(m.spills, 0)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ def test_medium_dataset(self):
+ m = ExternalMerger(self.agg, 10)
+ m.mergeValues(self.data)
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ m = ExternalMerger(self.agg, 10)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3))
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)) * 3)
+
+ def test_huge_dataset(self):
+ m = ExternalMerger(self.agg, 10)
+ m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
+ self.N * 10)
+ m._cleanup()
+
+
class PySparkTestCase(unittest.TestCase):
def setUp(self):
@@ -169,6 +226,15 @@ def test_transforming_cartesian_result(self):
cart = rdd1.cartesian(rdd2)
result = cart.map(lambda (x, y): x + y).collect()
+ def test_transforming_pickle_file(self):
+ # Regression test for SPARK-2601
+ data = self.sc.parallelize(["Hello", "World!"])
+ tempFile = tempfile.NamedTemporaryFile(delete=True)
+ tempFile.close()
+ data.saveAsPickleFile(tempFile.name)
+ pickled_file = self.sc.pickleFile(tempFile.name)
+ pickled_file.map(lambda x: x).collect()
+
def test_cartesian_on_textfile(self):
# Regression test for
path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
diff --git a/python/run-tests b/python/run-tests
index 9282aa47e8375..29f755fc0dcd3 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -61,6 +61,7 @@ run_test "pyspark/broadcast.py"
run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
unset PYSPARK_DOC_TEST
+run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
run_test "pyspark/mllib/_common.py"
run_test "pyspark/mllib/classification.py"
diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh
new file mode 100755
index 0000000000000..8398e6f19b511
--- /dev/null
+++ b/sbin/start-thriftserver.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# Shell script for starting the Spark SQL Thrift server
+
+# Enter posix mode for bash
+set -o posix
+
+# Figure out where Spark is installed
+FWDIR="$(cd `dirname $0`/..; pwd)"
+
+if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
+ echo "Usage: ./sbin/start-thriftserver [options]"
+ $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
+ exit 0
+fi
+
+CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2"
+exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 6decde3fcd62d..531bfddbf237b 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -32,7 +32,7 @@
Spark Project Catalyst
http://spark.apache.org/
- catalyst
+ catalyst
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 9887856b9c1c6..47c7ad076ad07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -50,6 +50,7 @@ trait HiveTypeCoercion {
StringToIntegralCasts ::
FunctionArgumentConversion ::
CastNulls ::
+ Division ::
Nil
/**
@@ -246,6 +247,8 @@ trait HiveTypeCoercion {
// No need to change other EqualTo operators as that actually makes sense for boolean types.
case e: EqualTo => e
+ // No need to change EqualNullSafe operators either.
+ case e: EqualNullSafe => e
// Otherwise turn them to Byte types so that there exists and ordering.
case p: BinaryComparison
if p.left.dataType == BooleanType && p.right.dataType == BooleanType =>
@@ -315,6 +318,23 @@ trait HiveTypeCoercion {
}
}
+ /**
+ * Hive only performs integral division with the DIV operator. The arguments to / are always
+ * converted to fractional types.
+ */
+ object Division extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes who's children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ // Decimal and Double remain the same
+ case d: Divide if d.dataType == DoubleType => d
+ case d: Divide if d.dataType == DecimalType => d
+
+ case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
+ }
+ }
+
/**
* Ensures that NullType gets casted to some other types under certain circumstances.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 15c98efbcabcf..5c8c810d9135a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -77,6 +77,7 @@ package object dsl {
def > (other: Expression) = GreaterThan(expr, other)
def >= (other: Expression) = GreaterThanOrEqual(expr, other)
def === (other: Expression) = EqualTo(expr, other)
+ def <=> (other: Expression) = EqualNullSafe(expr, other)
def !== (other: Expression) = Not(EqualTo(expr, other))
def in(list: Expression*) = In(expr, list)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b63406b94a4a3..06b94a98d3cd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -153,6 +153,22 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
}
}
+case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
+ def symbol = "<=>"
+ override def nullable = false
+ override def eval(input: Row): Any = {
+ val l = left.eval(input)
+ val r = right.eval(input)
+ if (l == null && r == null) {
+ true
+ } else if (l == null || r == null) {
+ false
+ } else {
+ l == r
+ }
+ }
+}
+
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
def symbol = "<"
override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _))
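
A plain-Python illustration, not Catalyst code, of the null-safe equality semantics the EqualNullSafe expression above implements: NULL <=> NULL is true, NULL against anything else is false, and the result itself is never NULL.

def equal_null_safe(l, r):
    # Mirrors EqualNullSafe#eval, with None standing in for SQL NULL.
    if l is None and r is None:
        return True
    if l is None or r is None:
        return False
    return l == r

assert equal_null_safe(None, None) is True    # NULL <=> NULL
assert equal_null_safe(1, None) is False      # 1 <=> NULL
assert equal_null_safe(None, 1) is False
assert equal_null_safe(3, 3) is True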
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c65987b7120b2..5f86d6047cb9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -153,6 +153,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
+ case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
+ case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index 1d5f033f0d274..a357c6ffb8977 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -43,8 +43,7 @@ case class NativeCommand(cmd: String) extends Command {
*/
case class SetCommand(key: Option[String], value: Option[String]) extends Command {
override def output = Seq(
- BoundReference(0, AttributeReference("key", StringType, nullable = false)()),
- BoundReference(1, AttributeReference("value", StringType, nullable = false)()))
+ BoundReference(1, AttributeReference("", StringType, nullable = false)()))
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index c3f5c26fdbe59..58f8c341e6676 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -451,11 +451,13 @@ class ExpressionEvaluationSuite extends FunSuite {
}
test("BinaryComparison") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)
+ val c5 = 'a.int.at(4)
+ val c6 = 'a.int.at(5)
checkEvaluation(LessThan(c1, c4), null, row)
checkEvaluation(LessThan(c1, c2), true, row)
@@ -469,6 +471,12 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(c1 >= c2, false, row)
checkEvaluation(c1 === c2, false, row)
checkEvaluation(c1 !== c2, true, row)
+ checkEvaluation(c4 <=> c1, false, row)
+ checkEvaluation(c1 <=> c4, false, row)
+ checkEvaluation(c4 <=> c6, true, row)
+ checkEvaluation(c3 <=> c5, true, row)
+ checkEvaluation(Literal(true) <=> Literal(null, BooleanType), false, row)
+ checkEvaluation(Literal(null, BooleanType) <=> Literal(true), false, row)
}
test("StringComparison") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index d607eed1bea89..0a27cce337482 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -83,7 +83,7 @@ class ConstantFoldingSuite extends PlanTest {
Literal(10) as Symbol("2*3+4"),
Literal(14) as Symbol("2*(3+4)"))
.where(Literal(true))
- .groupBy(Literal(3))(Literal(3) as Symbol("9/3"))
+ .groupBy(Literal(3.0))(Literal(3.0) as Symbol("9/3"))
.analyze
comparePlans(optimized, correctAnswer)
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c309c43804d97..3a038a2db6173 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -32,7 +32,7 @@
Spark Project SQL
http://spark.apache.org/
- sql
+ sql
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 2b787e14f3f15..41920c00b5a2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -30,12 +30,13 @@ import scala.collection.JavaConverters._
* SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads).
*/
trait SQLConf {
+ import SQLConf._
/** ************************ Spark SQL Params/Hints ******************* */
// TODO: refactor so that these hints accessors don't pollute the name space of SQLContext?
/** Number of partitions to use for shuffle operators. */
- private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
+ private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt
/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
@@ -43,11 +44,10 @@ trait SQLConf {
* effectively disables auto conversion.
* Hive setting: hive.auto.convert.join.noconditionaltask.size.
*/
- private[spark] def autoConvertJoinSize: Int =
- get("spark.sql.auto.convert.join.size", "10000").toInt
+ private[spark] def autoConvertJoinSize: Int = get(AUTO_CONVERT_JOIN_SIZE, "10000").toInt
/** A comma-separated list of table names marked to be broadcasted during joins. */
- private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "")
+ private[spark] def joinBroadcastTables: String = get(JOIN_BROADCAST_TABLES, "")
/** ********************** SQLConf functionality methods ************ */
@@ -61,7 +61,7 @@ trait SQLConf {
def set(key: String, value: String): Unit = {
require(key != null, "key cannot be null")
- require(value != null, s"value cannot be null for ${key}")
+ require(value != null, s"value cannot be null for $key")
settings.put(key, value)
}
@@ -90,3 +90,13 @@ trait SQLConf {
}
}
+
+object SQLConf {
+ val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size"
+ val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
+ val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables"
+
+ object Deprecated {
+ val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 98d2f89c8ae71..9293239131d52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.execution
+import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLConf, SQLContext}
trait Command {
/**
@@ -44,28 +45,53 @@ trait Command {
case class SetCommand(
key: Option[String], value: Option[String], output: Seq[Attribute])(
@transient context: SQLContext)
- extends LeafNode with Command {
+ extends LeafNode with Command with Logging {
- override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match {
+ override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match {
// Set value for key k.
case (Some(k), Some(v)) =>
- context.set(k, v)
- Array(k -> v)
+ if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
+ logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+ s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
+ context.set(SQLConf.SHUFFLE_PARTITIONS, v)
+ Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")
+ } else {
+ context.set(k, v)
+ Array(s"$k=$v")
+ }
// Query the value bound to key k.
case (Some(k), _) =>
- Array(k -> context.getOption(k).getOrElse(""))
+ // TODO (lian) This is just a workaround to make the Simba ODBC driver work.
+ // Should remove this once we get the ODBC driver updated.
+ if (k == "-v") {
+ val hiveJars = Seq(
+ "hive-exec-0.12.0.jar",
+ "hive-service-0.12.0.jar",
+ "hive-common-0.12.0.jar",
+ "hive-hwi-0.12.0.jar",
+ "hive-0.12.0.jar").mkString(":")
+
+ Array(
+ "system:java.class.path=" + hiveJars,
+ "system:sun.java.command=shark.SharkServer2")
+ }
+ else {
+ Array(s"$k=${context.getOption(k).getOrElse("")}")
+ }
// Query all key-value pairs that are set in the SQLConf of the context.
case (None, None) =>
- context.getAll
+ context.getAll.map { case (k, v) =>
+ s"$k=$v"
+ }
case _ =>
throw new IllegalArgumentException()
}
def execute(): RDD[Row] = {
- val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) }
+ val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) }
context.sparkContext.parallelize(rows, 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 08293f7f0ca30..1a58d73d9e7f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -54,10 +54,10 @@ class SQLConfSuite extends QueryTest {
assert(get(testKey, testVal + "_") == testVal)
assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
- sql("set mapred.reduce.tasks=20")
- assert(get("mapred.reduce.tasks", "0") == "20")
- sql("set mapred.reduce.tasks = 40")
- assert(get("mapred.reduce.tasks", "0") == "40")
+ sql("set some.property=20")
+ assert(get("some.property", "0") == "20")
+ sql("set some.property = 40")
+ assert(get("some.property", "0") == "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
@@ -70,4 +70,9 @@ class SQLConfSuite extends QueryTest {
clear()
}
+ test("deprecated property") {
+ clear()
+ sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
+ assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c07753c40b656..63c559bf4d9ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -424,25 +424,25 @@ class SQLQuerySuite extends QueryTest {
sql(s"SET $testKey=$testVal")
checkAnswer(
sql("SET"),
- Seq(Seq(testKey, testVal))
+ Seq(Seq(s"$testKey=$testVal"))
)
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
checkAnswer(
sql("set"),
Seq(
- Seq(testKey, testVal),
- Seq(testKey + testKey, testVal + testVal))
+ Seq(s"$testKey=$testVal"),
+ Seq(s"${testKey + testKey}=${testVal + testVal}"))
)
// "set key"
checkAnswer(
sql(s"SET $testKey"),
- Seq(Seq(testKey, testVal))
+ Seq(Seq(s"$testKey=$testVal"))
)
checkAnswer(
sql(s"SET $nonexistentKey"),
- Seq(Seq(nonexistentKey, ""))
+ Seq(Seq(s"$nonexistentKey="))
)
clear()
}
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
new file mode 100644
index 0000000000000..7fac90fdc596d
--- /dev/null
+++ b/sql/hive-thriftserver/pom.xml
@@ -0,0 +1,82 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent</artifactId>
+    <version>1.1.0-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>org.apache.spark</groupId>
+  <artifactId>spark-hive-thriftserver_2.10</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project Hive</name>
+  <url>http://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>hive-thriftserver</sbt.project.name>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-hive_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.spark-project.hive</groupId>
+      <artifactId>hive-cli</artifactId>
+      <version>${hive.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.spark-project.hive</groupId>
+      <artifactId>hive-jdbc</artifactId>
+      <version>${hive.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.spark-project.hive</groupId>
+      <artifactId>hive-beeline</artifactId>
+      <version>${hive.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-deploy-plugin</artifactId>
+        <configuration>
+          <skip>true</skip>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
new file mode 100644
index 0000000000000..ddbc2a79fb512
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService
+import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+
+/**
+ * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a
+ * `HiveThriftServer2` thrift server.
+ */
+private[hive] object HiveThriftServer2 extends Logging {
+ var LOG = LogFactory.getLog(classOf[HiveServer2])
+
+ def main(args: Array[String]) {
+ val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2")
+
+ if (!optionsProcessor.process(args)) {
+ logger.warn("Error starting HiveThriftServer2 with given arguments")
+ System.exit(-1)
+ }
+
+ val ss = new SessionState(new HiveConf(classOf[SessionState]))
+
+ // Set all properties specified via command line.
+ val hiveConf: HiveConf = ss.getConf
+ hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) =>
+ logger.debug(s"HiveConf var: $k=$v")
+ }
+
+ SessionState.start(ss)
+
+ logger.info("Starting SparkContext")
+ SparkSQLEnv.init()
+ SessionState.start(ss)
+
+ Runtime.getRuntime.addShutdownHook(
+ new Thread() {
+ override def run() {
+ SparkSQLEnv.sparkContext.stop()
+ }
+ }
+ )
+
+ try {
+ val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
+ server.init(hiveConf)
+ server.start()
+ logger.info("HiveThriftServer2 started")
+ } catch {
+ case e: Exception =>
+ logger.error("Error starting HiveThriftServer2", e)
+ System.exit(-1)
+ }
+ }
+}
+
+private[hive] class HiveThriftServer2(hiveContext: HiveContext)
+ extends HiveServer2
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ val sparkSqlCliService = new SparkSQLCLIService(hiveContext)
+ setSuperField(this, "cliService", sparkSqlCliService)
+ addService(sparkSqlCliService)
+
+ val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService)
+ setSuperField(this, "thriftCLIService", thriftCliService)
+ addService(thriftCliService)
+
+ initCompositeService(hiveConf)
+ }
+}
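
For reference, the startup sequence in main() above can be condensed into a small embedding sketch. This is only an illustration: `EmbeddedThriftServerSketch` is hypothetical, and since HiveThriftServer2 and SparkSQLEnv are `private[hive]`, it assumes placement in the same package.

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.session.SessionState

object EmbeddedThriftServerSketch {
  def start(): HiveThriftServer2 = {
    val hiveConf = new HiveConf(classOf[SessionState])
    SessionState.start(new SessionState(hiveConf))  // main() above also starts a Hive session first
    SparkSQLEnv.init()                              // creates the SparkContext and HiveContext
    val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
    server.init(hiveConf)
    server.start()                                  // Thrift service now listens on the configured port
    server
  }
}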
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
new file mode 100644
index 0000000000000..599294dfbb7d7
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+private[hive] object ReflectionUtils {
+ def setSuperField(obj : Object, fieldName: String, fieldValue: Object) {
+ setAncestorField(obj, 1, fieldName, fieldValue)
+ }
+
+ def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) {
+ val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next()
+ val field = ancestor.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ field.set(obj, fieldValue)
+ }
+
+ def getSuperField[T](obj: AnyRef, fieldName: String): T = {
+ getAncestorField[T](obj, 1, fieldName)
+ }
+
+ def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = {
+ val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next()
+ val field = ancestor.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ field.get(clazz).asInstanceOf[T]
+ }
+
+ def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = {
+ invoke(clazz, null, methodName, args: _*)
+ }
+
+ def invoke(
+ clazz: Class[_],
+ obj: AnyRef,
+ methodName: String,
+ args: (Class[_], AnyRef)*): AnyRef = {
+
+ val (types, values) = args.unzip
+ val method = clazz.getDeclaredMethod(methodName, types: _*)
+ method.setAccessible(true)
+ method.invoke(obj, values.toSeq: _*)
+ }
+}
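
These helpers exist because Hive's service classes keep their state in private fields several superclass levels above the Spark subclasses. A toy sketch of how the `level` argument is counted (the classes below are hypothetical, not part of the patch, and the snippet assumes same-package access to ReflectionUtils):

object ReflectionUtilsDemo {
  class Base { private var label: String = "base" }
  class Middle extends Base
  class Leaf extends Middle

  def main(args: Array[String]): Unit = {
    val leaf = new Leaf
    // level 1 is the immediate superclass (Middle), level 2 is Base, which declares "label".
    ReflectionUtils.setAncestorField(leaf, 2, "label", "patched")
    println(ReflectionUtils.getAncestorField[String](leaf, 2, "label"))  // prints "patched"
  }
}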
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
new file mode 100755
index 0000000000000..27268ecb923e9
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.io._
+import java.util.{ArrayList => JArrayList}
+
+import jline.{ConsoleReader, History}
+import org.apache.commons.lang.StringUtils
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
+import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
+import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils}
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.exec.Utilities
+import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.shims.ShimLoader
+import org.apache.thrift.transport.TSocket
+
+import org.apache.spark.sql.Logging
+
+private[hive] object SparkSQLCLIDriver {
+ private var prompt = "spark-sql"
+ private var continuedPrompt = "".padTo(prompt.length, ' ')
+ private var transport:TSocket = _
+
+ installSignalHandler()
+
+ /**
+ * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(),
+ * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while
+ * a command is being processed by the current thread.
+ */
+ def installSignalHandler() {
+ HiveInterruptUtils.add(new HiveInterruptCallback {
+ override def interrupt() {
+ // Handle remote execution mode
+ if (SparkSQLEnv.sparkContext != null) {
+ SparkSQLEnv.sparkContext.cancelAllJobs()
+ } else {
+ if (transport != null) {
+ // Force closing of TCP connection upon session termination
+ transport.getSocket.close()
+ }
+ }
+ }
+ })
+ }
+
+ def main(args: Array[String]) {
+ val oproc = new OptionsProcessor()
+ if (!oproc.process_stage1(args)) {
+ System.exit(1)
+ }
+
+ // NOTE: It is critical to do this here so that log4j is reinitialized
+ // before any of the other core hive classes are loaded
+ var logInitFailed = false
+ var logInitDetailMessage: String = null
+ try {
+ logInitDetailMessage = LogUtils.initHiveLog4j()
+ } catch {
+ case e: LogInitializationException =>
+ logInitFailed = true
+ logInitDetailMessage = e.getMessage
+ }
+
+ val sessionState = new CliSessionState(new HiveConf(classOf[SessionState]))
+
+ sessionState.in = System.in
+ try {
+ sessionState.out = new PrintStream(System.out, true, "UTF-8")
+ sessionState.info = new PrintStream(System.err, true, "UTF-8")
+ sessionState.err = new PrintStream(System.err, true, "UTF-8")
+ } catch {
+ case e: UnsupportedEncodingException => System.exit(3)
+ }
+
+ if (!oproc.process_stage2(sessionState)) {
+ System.exit(2)
+ }
+
+ if (!sessionState.getIsSilent) {
+ if (logInitFailed) System.err.println(logInitDetailMessage)
+ else SessionState.getConsole.printInfo(logInitDetailMessage)
+ }
+
+ // Set all properties specified via command line.
+ val conf: HiveConf = sessionState.getConf
+ sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] =>
+ conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String])
+ sessionState.getOverriddenConfigurations.put(
+ item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String])
+ }
+
+ SessionState.start(sessionState)
+
+ // Clean up after we exit
+ Runtime.getRuntime.addShutdownHook(
+ new Thread() {
+ override def run() {
+ SparkSQLEnv.stop()
+ }
+ }
+ )
+
+ // "-h" option has been passed, so connect to Hive thrift server.
+ if (sessionState.getHost != null) {
+ sessionState.connect()
+ if (sessionState.isRemoteMode) {
+ prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt
+ continuedPrompt = "".padTo(prompt.length, ' ')
+ }
+ }
+
+ if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) {
+ // Hadoop-20 and above - we need to augment classpath using hiveconf
+ // components.
+ // See also: code in ExecDriver.java
+ var loader = conf.getClassLoader
+ val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
+ if (StringUtils.isNotBlank(auxJars)) {
+ loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ","))
+ }
+ conf.setClassLoader(loader)
+ Thread.currentThread().setContextClassLoader(loader)
+ }
+
+ val cli = new SparkSQLCLIDriver
+ cli.setHiveVariables(oproc.getHiveVariables)
+
+ // TODO: work around to set the log output to the console, because the HiveContext
+ // will set the output into an invalid buffer.
+ sessionState.in = System.in
+ try {
+ sessionState.out = new PrintStream(System.out, true, "UTF-8")
+ sessionState.info = new PrintStream(System.err, true, "UTF-8")
+ sessionState.err = new PrintStream(System.err, true, "UTF-8")
+ } catch {
+ case e: UnsupportedEncodingException => System.exit(3)
+ }
+
+ // Execute -i init files (always in silent mode)
+ cli.processInitFiles(sessionState)
+
+ if (sessionState.execString != null) {
+ System.exit(cli.processLine(sessionState.execString))
+ }
+
+ try {
+ if (sessionState.fileName != null) {
+ System.exit(cli.processFile(sessionState.fileName))
+ }
+ } catch {
+ case e: FileNotFoundException =>
+ System.err.println(s"Could not open input file for reading. (${e.getMessage})")
+ System.exit(3)
+ }
+
+ val reader = new ConsoleReader()
+ reader.setBellEnabled(false)
+ // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true)))
+ CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e))
+
+ val historyDirectory = System.getProperty("user.home")
+
+ try {
+ if (new File(historyDirectory).exists()) {
+ val historyFile = historyDirectory + File.separator + ".hivehistory"
+ reader.setHistory(new History(new File(historyFile)))
+ } else {
+ System.err.println("WARNING: Directory for Hive history file: " + historyDirectory +
+ " does not exist. History will not be available during this session.")
+ }
+ } catch {
+ case e: Exception =>
+ System.err.println("WARNING: Encountered an error while trying to initialize Hive's " +
+ "history file. History will not be available during this session.")
+ System.err.println(e.getMessage)
+ }
+
+ val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport")
+ clientTransportTSocketField.setAccessible(true)
+
+ transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket]
+
+ var ret = 0
+ var prefix = ""
+ val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb",
+ classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState)
+
+ def promptWithCurrentDB = s"$prompt$currentDB"
+ def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic(
+ classOf[CliDriver], "spacesForString", classOf[String] -> currentDB)
+
+ var currentPrompt = promptWithCurrentDB
+ var line = reader.readLine(currentPrompt + "> ")
+
+ while (line != null) {
+ if (prefix.nonEmpty) {
+ prefix += '\n'
+ }
+
+ if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) {
+ line = prefix + line
+ ret = cli.processLine(line, true)
+ prefix = ""
+ currentPrompt = promptWithCurrentDB
+ } else {
+ prefix = prefix + line
+ currentPrompt = continuedPromptWithDBSpaces
+ }
+
+ line = reader.readLine(currentPrompt + "> ")
+ }
+
+ sessionState.close()
+
+ System.exit(ret)
+ }
+}
+
+private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
+ private val sessionState = SessionState.get().asInstanceOf[CliSessionState]
+
+ private val LOG = LogFactory.getLog("CliDriver")
+
+ private val console = new SessionState.LogHelper(LOG)
+
+ private val conf: Configuration =
+ if (sessionState != null) sessionState.getConf else new Configuration()
+
+ // Force initializing SparkSQLEnv. This is put here instead of in object SparkSQLCLIDriver
+ // because the Hive unit tests do not go through the main() code path.
+ if (!sessionState.isRemoteMode) {
+ SparkSQLEnv.init()
+ }
+
+ override def processCmd(cmd: String): Int = {
+ val cmd_trimmed: String = cmd.trim()
+ val tokens: Array[String] = cmd_trimmed.split("\\s+")
+ val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
+ if (cmd_trimmed.toLowerCase.equals("quit") ||
+ cmd_trimmed.toLowerCase.equals("exit") ||
+ tokens(0).equalsIgnoreCase("source") ||
+ cmd_trimmed.startsWith("!") ||
+ tokens(0).toLowerCase.equals("list") ||
+ sessionState.isRemoteMode) {
+ val start = System.currentTimeMillis()
+ super.processCmd(cmd)
+ val end = System.currentTimeMillis()
+ val timeTaken: Double = (end - start) / 1000.0
+ console.printInfo(s"Time taken: $timeTaken seconds")
+ 0
+ } else {
+ var ret = 0
+ val hconf = conf.asInstanceOf[HiveConf]
+ val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf)
+
+ if (proc != null) {
+ if (proc.isInstanceOf[Driver]) {
+ val driver = new SparkSQLDriver
+
+ driver.init()
+ val out = sessionState.out
+ val start:Long = System.currentTimeMillis()
+ if (sessionState.getIsVerbose) {
+ out.println(cmd)
+ }
+
+ ret = driver.run(cmd).getResponseCode
+ if (ret != 0) {
+ driver.close()
+ return ret
+ }
+
+ val res = new JArrayList[String]()
+
+ if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) {
+ // Print the column names.
+ Option(driver.getSchema.getFieldSchemas).map { fields =>
+ out.println(fields.map(_.getName).mkString("\t"))
+ }
+ }
+
+ try {
+ while (!out.checkError() && driver.getResults(res)) {
+ res.foreach(out.println)
+ res.clear()
+ }
+ } catch {
+ case e:IOException =>
+ console.printError(
+ s"""Failed with exception ${e.getClass.getName}: ${e.getMessage}
+ |${org.apache.hadoop.util.StringUtils.stringifyException(e)}
+ """.stripMargin)
+ ret = 1
+ }
+
+ val cret = driver.close()
+ if (ret == 0) {
+ ret = cret
+ }
+
+ val end = System.currentTimeMillis()
+ if (end > start) {
+ val timeTaken:Double = (end - start) / 1000.0
+ console.printInfo(s"Time taken: $timeTaken seconds", null)
+ }
+
+ // Destroy the driver to release all the locks.
+ driver.destroy()
+ } else {
+ if (sessionState.getIsVerbose) {
+ sessionState.out.println(tokens(0) + " " + cmd_1)
+ }
+ ret = proc.run(cmd_1).getResponseCode
+ }
+ }
+ ret
+ }
+ }
+}
+
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
new file mode 100644
index 0000000000000..42cbf363b274f
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.io.IOException
+import java.util.{List => JList}
+import javax.security.auth.login.LoginException
+
+import org.apache.commons.logging.Log
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.shims.ShimLoader
+import org.apache.hive.service.Service.STATE
+import org.apache.hive.service.auth.HiveAuthFactory
+import org.apache.hive.service.cli.CLIService
+import org.apache.hive.service.{AbstractService, Service, ServiceException}
+
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+
+private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
+ extends CLIService
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ setSuperField(this, "hiveConf", hiveConf)
+
+ val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext)
+ setSuperField(this, "sessionManager", sparkSqlSessionManager)
+ addService(sparkSqlSessionManager)
+
+ try {
+ HiveAuthFactory.loginFromKeytab(hiveConf)
+ val serverUserName = ShimLoader.getHadoopShims
+ .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf))
+ setSuperField(this, "serverUserName", serverUserName)
+ } catch {
+ case e @ (_: IOException | _: LoginException) =>
+ throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
+ }
+
+ initCompositeService(hiveConf)
+ }
+}
+
+private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
+ def initCompositeService(hiveConf: HiveConf) {
+ // Emulating `CompositeService.init(hiveConf)`
+ val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList")
+ serviceList.foreach(_.init(hiveConf))
+
+ // Emulating `AbstractService.init(hiveConf)`
+ invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED)
+ setAncestorField(this, 3, "hiveConf", hiveConf)
+ invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED)
+ getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.")
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
new file mode 100644
index 0000000000000..5202aa9903e03
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.commons.lang.exception.ExceptionUtils
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
+
+private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext)
+ extends Driver with Logging {
+
+ private var tableSchema: Schema = _
+ private var hiveResponse: Seq[String] = _
+
+ override def init(): Unit = {
+ }
+
+ private def getResultSetSchema(query: context.QueryExecution): Schema = {
+ val analyzed = query.analyzed
+ logger.debug(s"Result Schema: ${analyzed.output}")
+ if (analyzed.output.size == 0) {
+ new Schema(new FieldSchema("Response code", "string", "") :: Nil, null)
+ } else {
+ val fieldSchemas = analyzed.output.map { attr =>
+ new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
+ }
+
+ new Schema(fieldSchemas, null)
+ }
+ }
+
+ override def run(command: String): CommandProcessorResponse = {
+ val execution = context.executePlan(context.hql(command).logicalPlan)
+
+ // TODO unify the error code
+ try {
+ hiveResponse = execution.stringResult()
+ tableSchema = getResultSetSchema(execution)
+ new CommandProcessorResponse(0)
+ } catch {
+ case cause: Throwable =>
+ logger.error(s"Failed in [$command]", cause)
+ new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null)
+ }
+ }
+
+ override def close(): Int = {
+ hiveResponse = null
+ tableSchema = null
+ 0
+ }
+
+ override def getSchema: Schema = tableSchema
+
+ override def getResults(res: JArrayList[String]): Boolean = {
+ if (hiveResponse == null) {
+ false
+ } else {
+ res.addAll(hiveResponse)
+ hiveResponse = null
+ true
+ }
+ }
+
+ override def destroy() {
+ super.destroy()
+ hiveResponse = null
+ tableSchema = null
+ }
+}
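
A minimal usage sketch of the driver, mirroring how SparkSQLCLIDriver consumes it above. It assumes SparkSQLEnv.init() has already run (so the default HiveContext argument exists), same-package access to the `private[hive]` classes, and an illustrative table name:

import java.util.{ArrayList => JArrayList}
import scala.collection.JavaConversions._

object SparkSQLDriverExample {
  def main(args: Array[String]): Unit = {
    SparkSQLEnv.init()                                              // provides the default HiveContext
    val driver = new SparkSQLDriver()
    driver.init()
    val response = driver.run("SELECT key, val FROM src LIMIT 3")   // "src" is illustrative
    if (response.getResponseCode == 0) {
      val rows = new JArrayList[String]()
      while (driver.getResults(rows)) {
        rows.foreach(println)                                       // each entry is one tab-separated result line
        rows.clear()
      }
    }
    driver.close()
  }
}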
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
new file mode 100644
index 0000000000000..451c3bd7b9352
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import org.apache.hadoop.hive.ql.session.SessionState
+
+import org.apache.spark.scheduler.{SplitInfo, StatsReportListener}
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+/** A singleton object for the master program. The slaves should not access this. */
+private[hive] object SparkSQLEnv extends Logging {
+ logger.debug("Initializing SparkSQLEnv")
+
+ var hiveContext: HiveContext = _
+ var sparkContext: SparkContext = _
+
+ def init() {
+ if (hiveContext == null) {
+ sparkContext = new SparkContext(new SparkConf()
+ .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}"))
+
+ sparkContext.addSparkListener(new StatsReportListener())
+
+ hiveContext = new HiveContext(sparkContext) {
+ @transient override lazy val sessionState = SessionState.get()
+ @transient override lazy val hiveconf = sessionState.getConf
+ }
+ }
+ }
+
+ /** Cleans up and shuts down the Spark SQL environments. */
+ def stop() {
+ logger.debug("Shutting down Spark SQL Environment")
+ // Stop the SparkContext
+ if (SparkSQLEnv.sparkContext != null) {
+ sparkContext.stop()
+ sparkContext = null
+ hiveContext = null
+ }
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
new file mode 100644
index 0000000000000..6b3275b4eaf04
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.concurrent.Executors
+
+import org.apache.commons.logging.Log
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hive.service.cli.session.SessionManager
+
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager
+
+private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
+ extends SessionManager
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ setSuperField(this, "hiveConf", hiveConf)
+
+ val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
+ setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize))
+ getAncestorField[Log](this, 3, "LOG").info(
+ s"HiveServer2: Async execution pool size $backgroundPoolSize")
+
+ val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
+ setSuperField(this, "operationManager", sparkSqlOperationManager)
+ addService(sparkSqlOperationManager)
+
+ initCompositeService(hiveConf)
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
new file mode 100644
index 0000000000000..a4e1f3e762e89
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver.server
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import scala.math.{random, round}
+
+import java.sql.Timestamp
+import java.util.{Map => JMap}
+
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.metastore.api.FieldSchema
+import org.apache.hive.service.cli._
+import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
+import org.apache.hive.service.cli.session.HiveSession
+
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils
+import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
+import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow}
+
+/**
+ * Executes queries using Spark SQL, and maintains a list of handles to active queries.
+ */
+class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging {
+ val handleToOperation = ReflectionUtils
+ .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")
+
+ override def newExecuteStatementOperation(
+ parentSession: HiveSession,
+ statement: String,
+ confOverlay: JMap[String, String],
+ async: Boolean): ExecuteStatementOperation = synchronized {
+
+ val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) {
+ private var result: SchemaRDD = _
+ private var iter: Iterator[SparkRow] = _
+ private var dataTypes: Array[DataType] = _
+
+ def close(): Unit = {
+ // RDDs will be cleaned automatically upon garbage collection.
+ logger.debug("CLOSING")
+ }
+
+ def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = {
+ if (!iter.hasNext) {
+ new RowSet()
+ } else {
+ val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No.
+ var curRow = 0
+ var rowSet = new ArrayBuffer[Row](maxRows)
+
+ while (curRow < maxRows && iter.hasNext) {
+ val sparkRow = iter.next()
+ val row = new Row()
+ var curCol = 0
+
+ while (curCol < sparkRow.length) {
+ dataTypes(curCol) match {
+ case StringType =>
+ row.addString(sparkRow(curCol).asInstanceOf[String])
+ case IntegerType =>
+ row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol)))
+ case BooleanType =>
+ row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol)))
+ case DoubleType =>
+ row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol)))
+ case FloatType =>
+ row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol)))
+ case DecimalType =>
+ val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal
+ row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
+ case LongType =>
+ row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol)))
+ case ByteType =>
+ row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol)))
+ case ShortType =>
+ row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol)))
+ case TimestampType =>
+ row.addColumnValue(
+ ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp]))
+ case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+ val hiveString = result
+ .queryExecution
+ .asInstanceOf[HiveContext#QueryExecution]
+ .toHiveString((sparkRow.get(curCol), dataTypes(curCol)))
+ row.addColumnValue(ColumnValue.stringValue(hiveString))
+ }
+ curCol += 1
+ }
+ rowSet += row
+ curRow += 1
+ }
+ new RowSet(rowSet, 0)
+ }
+ }
+
+ def getResultSetSchema: TableSchema = {
+ logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}")
+ if (result.queryExecution.analyzed.output.size == 0) {
+ new TableSchema(new FieldSchema("Result", "string", "") :: Nil)
+ } else {
+ val schema = result.queryExecution.analyzed.output.map { attr =>
+ new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
+ }
+ new TableSchema(schema)
+ }
+ }
+
+ def run(): Unit = {
+ logger.info(s"Running query '$statement'")
+ setState(OperationState.RUNNING)
+ try {
+ result = hiveContext.hql(statement)
+ logger.debug(result.queryExecution.toString())
+ val groupId = round(random * 1000000).toString
+ hiveContext.sparkContext.setJobGroup(groupId, statement)
+ iter = result.queryExecution.toRdd.toLocalIterator
+ dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
+ setHasResultSet(true)
+ } catch {
+ // We actually do need to catch Throwable here, since some failures don't inherit
+ // from Exception and HiveServer would silently swallow them.
+ case e: Throwable =>
+ logger.error("Error executing query:", e)
+ throw new HiveSQLException(e.toString)
+ }
+ setState(OperationState.FINISHED)
+ }
+ }
+
+ handleToOperation.put(operation.getHandle, operation)
+ operation
+ }
+}
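
The run()/getNextRowSet pair above streams results through `toRdd.toLocalIterator` and hands them out at most `maxRows` at a time, rather than collecting the whole result set on the driver. The general pattern, as a standalone sketch (the `RowBatching` object is illustrative):

import org.apache.spark.rdd.RDD

object RowBatching {
  // Pull rows through a local iterator and return them in bounded batches,
  // instead of collect()-ing the entire result set onto the driver at once.
  def batches[T](rdd: RDD[T], maxRows: Int): Iterator[Seq[T]] =
    rdd.toLocalIterator.grouped(maxRows)
}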
diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt
new file mode 100644
index 0000000000000..850f8014b6f05
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt
@@ -0,0 +1,5 @@
+238val_238
+86val_86
+311val_311
+27val_27
+165val_165
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
new file mode 100644
index 0000000000000..69f19f826a802
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.io.{BufferedReader, InputStreamReader, PrintWriter}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils {
+ val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli")
+ val METASTORE_PATH = TestUtils.getMetastorePath("cli")
+
+ override def beforeAll() {
+ val pb = new ProcessBuilder(
+ "../../bin/spark-sql",
+ "--master",
+ "local",
+ "--hiveconf",
+ s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true",
+ "--hiveconf",
+ "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH)
+
+ process = pb.start()
+ outputWriter = new PrintWriter(process.getOutputStream, true)
+ inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
+ errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
+ waitForOutput(inputReader, "spark-sql>")
+ }
+
+ override def afterAll() {
+ process.destroy()
+ process.waitFor()
+ }
+
+ test("simple commands") {
+ val dataFilePath = getDataFile("data/files/small_kv.txt")
+ executeQuery("create table hive_test1(key int, val string);")
+ executeQuery("load data local inpath '" + dataFilePath + "' overwrite into table hive_test1;")
+ executeQuery("cache table hive_test1", "Time taken")
+ }
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
new file mode 100644
index 0000000000000..fe3403b3292ec
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent._
+
+import java.io.{BufferedReader, InputStreamReader}
+import java.net.ServerSocket
+import java.sql.{Connection, DriverManager, Statement}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+
+/**
+ * Test for the HiveThriftServer2 using JDBC.
+ */
+class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging {
+
+ val WAREHOUSE_PATH = getTempFilePath("warehouse")
+ val METASTORE_PATH = getTempFilePath("metastore")
+
+ val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver"
+ val TABLE = "test"
+ val HOST = "localhost"
+ val PORT = {
+ // Let the system choose a random available port to avoid collision with other parallel
+ // builds.
+ val socket = new ServerSocket(0)
+ val port = socket.getLocalPort
+ socket.close()
+ port
+ }
+
+ // If verbose is true, the test program will print all outputs coming from the Hive Thrift server.
+ val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean
+
+ Class.forName(DRIVER_NAME)
+
+ override def beforeAll() { launchServer() }
+
+ override def afterAll() { stopServer() }
+
+ private def launchServer(args: Seq[String] = Seq.empty) {
+ // Fork a new process to start the Hive Thrift server. We do this because it is
+ // hard to clean up Hive resources entirely, so we just start a new process and kill
+ // that process for cleanup.
+ val defaultArgs = Seq(
+ "../../sbin/start-thriftserver.sh",
+ "--master local",
+ "--hiveconf",
+ "hive.root.logger=INFO,console",
+ "--hiveconf",
+ s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true",
+ "--hiveconf",
+ s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH")
+ val pb = new ProcessBuilder(defaultArgs ++ args)
+ val environment = pb.environment()
+ environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString)
+ environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST)
+ process = pb.start()
+ inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
+ errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
+ waitForOutput(inputReader, "ThriftBinaryCLIService listening on")
+
+ // Spawn a thread to read the output from the forked process.
+ // Note that this is necessary since in some configurations, log4j could be blocked
+ // if its output to stderr is not read, eventually blocking the entire test suite.
+ future {
+ while (true) {
+ val stdout = readFrom(inputReader)
+ val stderr = readFrom(errorReader)
+ if (VERBOSE && stdout.length > 0) {
+ println(stdout)
+ }
+ if (VERBOSE && stderr.length > 0) {
+ println(stderr)
+ }
+ Thread.sleep(50)
+ }
+ }
+ }
+
+ private def stopServer() {
+ process.destroy()
+ process.waitFor()
+ }
+
+ test("test query execution against a Hive Thrift server") {
+ Thread.sleep(5 * 1000)
+ val dataFilePath = getDataFile("data/files/small_kv.txt")
+ val stmt = createStatement()
+ stmt.execute("DROP TABLE IF EXISTS test")
+ stmt.execute("DROP TABLE IF EXISTS test_cached")
+ stmt.execute("CREATE TABLE test(key int, val string)")
+ stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
+ stmt.execute("CREATE TABLE test_cached as select * from test limit 4")
+ stmt.execute("CACHE TABLE test_cached")
+
+ var rs = stmt.executeQuery("select count(*) from test")
+ rs.next()
+ assert(rs.getInt(1) === 5)
+
+ rs = stmt.executeQuery("select count(*) from test_cached")
+ rs.next()
+ assert(rs.getInt(1) === 4)
+
+ stmt.close()
+ }
+
+ def getConnection: Connection = {
+ val connectURI = s"jdbc:hive2://localhost:$PORT/"
+ DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")
+ }
+
+ def createStatement(): Statement = getConnection.createStatement()
+}
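
Outside the test harness, a client talks to the server through the stock Hive JDBC driver in the same way. A minimal standalone sketch, assuming a server already listening on HiveServer2's default port 10000 (the suite above picks a random free port instead), an existing table named `test`, and an illustrative user name:

import java.sql.DriverManager

object ThriftServerClientExample {
  def main(args: Array[String]): Unit = {
    Class.forName("org.apache.hive.jdbc.HiveDriver")
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/", "user", "")
    val stmt = conn.createStatement()
    val rs = stmt.executeQuery("SELECT count(*) FROM test")
    while (rs.next()) println(rs.getInt(1))
    rs.close(); stmt.close(); conn.close()
  }
}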
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
new file mode 100644
index 0000000000000..bb2242618fbef
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.io.{BufferedReader, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.Date
+
+import org.apache.hadoop.hive.common.LogUtils
+import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
+
+object TestUtils {
+ val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss")
+
+ def getWarehousePath(prefix: String): String = {
+ System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" +
+ timestamp.format(new Date)
+ }
+
+ def getMetastorePath(prefix: String): String = {
+ System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" +
+ timestamp.format(new Date)
+ }
+
+ // Dummy function for initializing the log4j properties.
+ def init() { }
+
+ // initialize log4j
+ try {
+ LogUtils.initHiveLog4j()
+ } catch {
+ case e: LogInitializationException => // Ignore the error.
+ }
+}
+
+trait TestUtils {
+ var process : Process = null
+ var outputWriter : PrintWriter = null
+ var inputReader : BufferedReader = null
+ var errorReader : BufferedReader = null
+
+ def executeQuery(
+ cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = {
+ println("Executing: " + cmd + ", expecting output: " + outputMessage)
+ outputWriter.write(cmd + "\n")
+ outputWriter.flush()
+ waitForQuery(timeout, outputMessage)
+ }
+
+ protected def waitForQuery(timeout: Long, message: String): String = {
+ if (waitForOutput(errorReader, message, timeout)) {
+ Thread.sleep(500)
+ readOutput()
+ } else {
+ assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput())
+ null
+ }
+ }
+
+ // Wait for the specified str to appear in the output.
+ protected def waitForOutput(
+ reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = {
+ val startTime = System.currentTimeMillis
+ var out = ""
+ while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) {
+ out += readFrom(reader)
+ }
+ out.contains(str)
+ }
+
+ // Read stdout output and filter out garbage collection messages.
+ protected def readOutput(): String = {
+ val output = readFrom(inputReader)
+ // Remove GC Messages
+ val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC"))
+ .mkString("\n")
+ filteredOutput
+ }
+
+ protected def readFrom(reader: BufferedReader): String = {
+ var out = ""
+ var c = 0
+ while (reader.ready) {
+ c = reader.read()
+ out += c.asInstanceOf[Char]
+ }
+ out
+ }
+
+ protected def getDataFile(name: String) = {
+ Thread.currentThread().getContextClassLoader.getResource(name)
+ }
+}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 8b451973a47a1..c69e93ba2b9ba 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -196,7 +196,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Hive returns the results of describe as plain text. Comments with multiple lines
// introduce extra lines in the Hive results, which make the result comparison fail.
- "describe_comment_indent"
+ "describe_comment_indent",
+
+ // Limit clause without an ordering, which causes failure.
+ "orc_predicate_pushdown"
)
/**
@@ -503,6 +506,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"join_hive_626",
"join_map_ppr",
"join_nulls",
+ "join_nullsafe",
"join_rc",
"join_reorder2",
"join_reorder3",
@@ -734,6 +738,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_double",
"udf_E",
"udf_elt",
+ "udf_equal",
"udf_exp",
"udf_field",
"udf_find_in_set",
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 1699ffe06ce15..93d00f7c37c9b 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -32,7 +32,7 @@
Spark Project Hive
http://spark.apache.org/
- hive
+ hive
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index d1b9f7b3c1ebc..69383f2e2d86d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -252,9 +252,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
- ShortType, DecimalType, TimestampType)
+ ShortType, DecimalType, TimestampType, BinaryType)
- protected def toHiveString(a: (Any, DataType)): String = a match {
+ protected[sql] def toHiveString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
struct.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
@@ -268,6 +268,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}.toSeq.sorted.mkString("{", ",", "}")
case (null, _) => "NULL"
case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString
+ case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8")
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 8af5c11cec40c..57738aabff176 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -260,7 +260,7 @@ private[hive] case class MetastoreRelation
// org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException
// which indicates the SerDe we used is not Serializable.
- def hiveQlTable = new Table(table)
+ @transient lazy val hiveQlTable = new Table(table)
def hiveQlPartitions = partitions.map { p =>
new Partition(hiveQlTable, p)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index c4ca9f362a04d..e6ab68b563f8d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -925,12 +925,14 @@ private[hive] object HiveQl {
case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
- case Token(DIV(), left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token(DIV(), left :: right:: Nil) =>
+ Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
/* Comparisons */
case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
+ case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right))
case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
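
The new `<=>` token is parsed into EqualNullSafe; its semantics, which the join_nullsafe golden files below exercise, can be summarized with a small helper (written against Option purely for illustration; `NullSafeEq` is not part of the patch). Note also that DIV now casts its result to LongType, so integer DIV yields whole numbers while `/` keeps producing fractional values, matching the div and division golden outputs that follow.

object NullSafeEq {
  // "<=>": both-NULL compares true, one-NULL compares false, otherwise ordinary equality.
  def apply[A](a: Option[A], b: Option[A]): Boolean = (a, b) match {
    case (None, None)          => true
    case (None, _) | (_, None) => false
    case (Some(x), Some(y))    => x == y
  }
}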
diff --git a/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
new file mode 100644
index 0000000000000..17ba0bea723c6
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
@@ -0,0 +1 @@
+0 0 0 1 2
diff --git a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
new file mode 100644
index 0000000000000..7b7a9175114ce
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
@@ -0,0 +1 @@
+2.0 0.5 0.3333333333333333 0.002
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60 b/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-1-5644ab44e5ba9f2941216b8d5dc33a99 b/sql/hive/src/test/resources/golden/join_nullsafe-1-5644ab44e5ba9f2941216b8d5dc33a99
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919 b/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919
new file mode 100644
index 0000000000000..31c409082cc2f
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919
@@ -0,0 +1,2 @@
+NULL 10 10 NULL NULL 10
+100 100 100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-11-3aa243002a5363b84556736ef71613b1 b/sql/hive/src/test/resources/golden/join_nullsafe-11-3aa243002a5363b84556736ef71613b1
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e b/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e
new file mode 100644
index 0000000000000..9b77d13cbaab2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e
@@ -0,0 +1,4 @@
+NULL NULL NULL NULL NULL NULL
+NULL 10 10 NULL NULL 10
+10 NULL NULL 10 10 NULL
+100 100 100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993 b/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993
new file mode 100644
index 0000000000000..47c0709d39851
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993
@@ -0,0 +1,12 @@
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+10 NULL NULL 10
+48 NULL NULL NULL
+100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c b/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c
new file mode 100644
index 0000000000000..36ba48516b658
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c
@@ -0,0 +1,12 @@
+NULL NULL NULL NULL
+NULL NULL NULL 35
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+10 NULL NULL 10
+100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da b/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da
new file mode 100644
index 0000000000000..fc1fd198cf8be
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da
@@ -0,0 +1,13 @@
+NULL NULL NULL NULL
+NULL NULL NULL 35
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+10 NULL NULL 10
+48 NULL NULL NULL
+100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97 b/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97
new file mode 100644
index 0000000000000..1cc70524f9d6d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97
@@ -0,0 +1,11 @@
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+10 NULL NULL 10
+100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8 b/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8
new file mode 100644
index 0000000000000..1cc70524f9d6d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8
@@ -0,0 +1,11 @@
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+10 NULL NULL 10
+100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-18-10b2051e65cac50ee1ea1c138ec192c8 b/sql/hive/src/test/resources/golden/join_nullsafe-18-10b2051e65cac50ee1ea1c138ec192c8
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-19-23ab7ac8229a53d391195be7ca092429 b/sql/hive/src/test/resources/golden/join_nullsafe-19-23ab7ac8229a53d391195be7ca092429
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-2-793e288c9e0971f0bf3f37493f76dc7 b/sql/hive/src/test/resources/golden/join_nullsafe-2-793e288c9e0971f0bf3f37493f76dc7
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-20-d6fc260320c577eec9a5db0d4135d224 b/sql/hive/src/test/resources/golden/join_nullsafe-20-d6fc260320c577eec9a5db0d4135d224
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-21-a60dae725ffc543f805242611d99de4e b/sql/hive/src/test/resources/golden/join_nullsafe-21-a60dae725ffc543f805242611d99de4e
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-22-24c80d0f9e3d72c48d947770fa184985 b/sql/hive/src/test/resources/golden/join_nullsafe-22-24c80d0f9e3d72c48d947770fa184985
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-23-3fe6ae20cab3417759dcc654a3a26746 b/sql/hive/src/test/resources/golden/join_nullsafe-23-3fe6ae20cab3417759dcc654a3a26746
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412 b/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9 b/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71 b/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6 b/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6
new file mode 100644
index 0000000000000..66482299904bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL 10
+NULL NULL NULL 10
+NULL NULL NULL 35
+NULL NULL NULL 35
+NULL NULL NULL 110
+NULL NULL NULL 110
+NULL NULL NULL 135
+NULL NULL NULL 135
+NULL 10 NULL NULL
+NULL 10 NULL NULL
+NULL 10 NULL 10
+NULL 10 NULL 35
+NULL 10 NULL 110
+NULL 10 NULL 135
+NULL 35 NULL NULL
+NULL 35 NULL NULL
+NULL 35 NULL 10
+NULL 35 NULL 35
+NULL 35 NULL 110
+NULL 35 NULL 135
+NULL 110 NULL NULL
+NULL 110 NULL NULL
+NULL 110 NULL 10
+NULL 110 NULL 35
+NULL 110 NULL 110
+NULL 110 NULL 135
+NULL 135 NULL NULL
+NULL 135 NULL NULL
+NULL 135 NULL 10
+NULL 135 NULL 35
+NULL 135 NULL 110
+NULL 135 NULL 135
+10 NULL 10 NULL
+48 NULL 48 NULL
+100 100 100 100
+110 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525 b/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525
new file mode 100644
index 0000000000000..2efbef0484452
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525
@@ -0,0 +1,14 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL 10 NULL 10
+NULL 35 NULL 35
+NULL 110 NULL 110
+NULL 135 NULL 135
+10 NULL 10 NULL
+48 NULL 48 NULL
+100 100 100 100
+110 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf b/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf
new file mode 100644
index 0000000000000..66482299904bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL 10
+NULL NULL NULL 10
+NULL NULL NULL 35
+NULL NULL NULL 35
+NULL NULL NULL 110
+NULL NULL NULL 110
+NULL NULL NULL 135
+NULL NULL NULL 135
+NULL 10 NULL NULL
+NULL 10 NULL NULL
+NULL 10 NULL 10
+NULL 10 NULL 35
+NULL 10 NULL 110
+NULL 10 NULL 135
+NULL 35 NULL NULL
+NULL 35 NULL NULL
+NULL 35 NULL 10
+NULL 35 NULL 35
+NULL 35 NULL 110
+NULL 35 NULL 135
+NULL 110 NULL NULL
+NULL 110 NULL NULL
+NULL 110 NULL 10
+NULL 110 NULL 35
+NULL 110 NULL 110
+NULL 110 NULL 135
+NULL 135 NULL NULL
+NULL 135 NULL NULL
+NULL 135 NULL 10
+NULL 135 NULL 35
+NULL 135 NULL 110
+NULL 135 NULL 135
+10 NULL 10 NULL
+48 NULL 48 NULL
+100 100 100 100
+110 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-3-ae378fc0f875a21884e58fa35a6d52cd b/sql/hive/src/test/resources/golden/join_nullsafe-3-ae378fc0f875a21884e58fa35a6d52cd
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88 b/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88
new file mode 100644
index 0000000000000..66482299904bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL 10
+NULL NULL NULL 10
+NULL NULL NULL 35
+NULL NULL NULL 35
+NULL NULL NULL 110
+NULL NULL NULL 110
+NULL NULL NULL 135
+NULL NULL NULL 135
+NULL 10 NULL NULL
+NULL 10 NULL NULL
+NULL 10 NULL 10
+NULL 10 NULL 35
+NULL 10 NULL 110
+NULL 10 NULL 135
+NULL 35 NULL NULL
+NULL 35 NULL NULL
+NULL 35 NULL 10
+NULL 35 NULL 35
+NULL 35 NULL 110
+NULL 35 NULL 135
+NULL 110 NULL NULL
+NULL 110 NULL NULL
+NULL 110 NULL 10
+NULL 110 NULL 35
+NULL 110 NULL 110
+NULL 110 NULL 135
+NULL 135 NULL NULL
+NULL 135 NULL NULL
+NULL 135 NULL 10
+NULL 135 NULL 35
+NULL 135 NULL 110
+NULL 135 NULL 135
+10 NULL 10 NULL
+48 NULL 48 NULL
+100 100 100 100
+110 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf b/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf
new file mode 100644
index 0000000000000..66482299904bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL 10
+NULL NULL NULL 10
+NULL NULL NULL 35
+NULL NULL NULL 35
+NULL NULL NULL 110
+NULL NULL NULL 110
+NULL NULL NULL 135
+NULL NULL NULL 135
+NULL 10 NULL NULL
+NULL 10 NULL NULL
+NULL 10 NULL 10
+NULL 10 NULL 35
+NULL 10 NULL 110
+NULL 10 NULL 135
+NULL 35 NULL NULL
+NULL 35 NULL NULL
+NULL 35 NULL 10
+NULL 35 NULL 35
+NULL 35 NULL 110
+NULL 35 NULL 135
+NULL 110 NULL NULL
+NULL 110 NULL NULL
+NULL 110 NULL 10
+NULL 110 NULL 35
+NULL 110 NULL 110
+NULL 110 NULL 135
+NULL 135 NULL NULL
+NULL 135 NULL NULL
+NULL 135 NULL 10
+NULL 135 NULL 35
+NULL 135 NULL 110
+NULL 135 NULL 135
+10 NULL 10 NULL
+48 NULL 48 NULL
+100 100 100 100
+110 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87 b/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87
new file mode 100644
index 0000000000000..ea001a222f357
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87
@@ -0,0 +1,40 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL NULL 48 NULL
+NULL NULL 110 NULL
+NULL NULL 110 NULL
+NULL NULL 148 NULL
+NULL NULL 148 NULL
+NULL 10 NULL NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 10 110 NULL
+NULL 10 148 NULL
+NULL 35 NULL NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+NULL 35 110 NULL
+NULL 35 148 NULL
+NULL 110 NULL NULL
+NULL 110 NULL NULL
+NULL 110 10 NULL
+NULL 110 48 NULL
+NULL 110 110 NULL
+NULL 110 148 NULL
+NULL 135 NULL NULL
+NULL 135 NULL NULL
+NULL 135 10 NULL
+NULL 135 48 NULL
+NULL 135 110 NULL
+NULL 135 148 NULL
+10 NULL NULL 10
+100 100 100 100
+110 NULL NULL 110
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c b/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c
new file mode 100644
index 0000000000000..ea001a222f357
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c
@@ -0,0 +1,40 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL NULL 48 NULL
+NULL NULL 110 NULL
+NULL NULL 110 NULL
+NULL NULL 148 NULL
+NULL NULL 148 NULL
+NULL 10 NULL NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 10 110 NULL
+NULL 10 148 NULL
+NULL 35 NULL NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+NULL 35 110 NULL
+NULL 35 148 NULL
+NULL 110 NULL NULL
+NULL 110 NULL NULL
+NULL 110 10 NULL
+NULL 110 48 NULL
+NULL 110 110 NULL
+NULL 110 148 NULL
+NULL 135 NULL NULL
+NULL 135 NULL NULL
+NULL 135 10 NULL
+NULL 135 48 NULL
+NULL 135 110 NULL
+NULL 135 148 NULL
+10 NULL NULL 10
+100 100 100 100
+110 NULL NULL 110
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223 b/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223
new file mode 100644
index 0000000000000..1093bd89f6e3f
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL NULL 48 NULL
+NULL NULL 110 NULL
+NULL NULL 110 NULL
+NULL NULL 148 NULL
+NULL NULL 148 NULL
+NULL 10 NULL NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 10 110 NULL
+NULL 10 148 NULL
+NULL 35 NULL NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+NULL 35 110 NULL
+NULL 35 148 NULL
+NULL 110 NULL NULL
+NULL 110 NULL NULL
+NULL 110 10 NULL
+NULL 110 48 NULL
+NULL 110 110 NULL
+NULL 110 148 NULL
+NULL 135 NULL NULL
+NULL 135 NULL NULL
+NULL 135 10 NULL
+NULL 135 48 NULL
+NULL 135 110 NULL
+NULL 135 148 NULL
+10 NULL NULL 10
+48 NULL NULL NULL
+100 100 100 100
+110 NULL NULL 110
+148 NULL NULL NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc b/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc
new file mode 100644
index 0000000000000..9cf0036674d6e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL 35
+NULL NULL NULL 135
+NULL NULL 10 NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL NULL 48 NULL
+NULL NULL 110 NULL
+NULL NULL 110 NULL
+NULL NULL 148 NULL
+NULL NULL 148 NULL
+NULL 10 NULL NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 10 110 NULL
+NULL 10 148 NULL
+NULL 35 NULL NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+NULL 35 110 NULL
+NULL 35 148 NULL
+NULL 110 NULL NULL
+NULL 110 NULL NULL
+NULL 110 10 NULL
+NULL 110 48 NULL
+NULL 110 110 NULL
+NULL 110 148 NULL
+NULL 135 NULL NULL
+NULL 135 NULL NULL
+NULL 135 10 NULL
+NULL 135 48 NULL
+NULL 135 110 NULL
+NULL 135 148 NULL
+10 NULL NULL 10
+100 100 100 100
+110 NULL NULL 110
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57 b/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57
new file mode 100644
index 0000000000000..77f6a8ddd7c28
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL NULL 48 NULL
+NULL NULL 110 NULL
+NULL NULL 110 NULL
+NULL NULL 148 NULL
+NULL NULL 148 NULL
+NULL 10 NULL 10
+NULL 35 NULL 35
+NULL 110 NULL 110
+NULL 135 NULL 135
+10 NULL NULL NULL
+10 NULL NULL NULL
+10 NULL 10 NULL
+10 NULL 48 NULL
+10 NULL 110 NULL
+10 NULL 148 NULL
+48 NULL NULL NULL
+48 NULL NULL NULL
+48 NULL 10 NULL
+48 NULL 48 NULL
+48 NULL 110 NULL
+48 NULL 148 NULL
+100 100 100 100
+110 NULL NULL NULL
+110 NULL NULL NULL
+110 NULL 10 NULL
+110 NULL 48 NULL
+110 NULL 110 NULL
+110 NULL 148 NULL
+148 NULL NULL NULL
+148 NULL NULL NULL
+148 NULL 10 NULL
+148 NULL 48 NULL
+148 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245 b/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245
new file mode 100644
index 0000000000000..77f6a8ddd7c28
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL NULL 48 NULL
+NULL NULL 110 NULL
+NULL NULL 110 NULL
+NULL NULL 148 NULL
+NULL NULL 148 NULL
+NULL 10 NULL 10
+NULL 35 NULL 35
+NULL 110 NULL 110
+NULL 135 NULL 135
+10 NULL NULL NULL
+10 NULL NULL NULL
+10 NULL 10 NULL
+10 NULL 48 NULL
+10 NULL 110 NULL
+10 NULL 148 NULL
+48 NULL NULL NULL
+48 NULL NULL NULL
+48 NULL 10 NULL
+48 NULL 48 NULL
+48 NULL 110 NULL
+48 NULL 148 NULL
+100 100 100 100
+110 NULL NULL NULL
+110 NULL NULL NULL
+110 NULL 10 NULL
+110 NULL 48 NULL
+110 NULL 110 NULL
+110 NULL 148 NULL
+148 NULL NULL NULL
+148 NULL NULL NULL
+148 NULL 10 NULL
+148 NULL 48 NULL
+148 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762 b/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762
new file mode 100644
index 0000000000000..77f6a8ddd7c28
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL NULL 48 NULL
+NULL NULL 110 NULL
+NULL NULL 110 NULL
+NULL NULL 148 NULL
+NULL NULL 148 NULL
+NULL 10 NULL 10
+NULL 35 NULL 35
+NULL 110 NULL 110
+NULL 135 NULL 135
+10 NULL NULL NULL
+10 NULL NULL NULL
+10 NULL 10 NULL
+10 NULL 48 NULL
+10 NULL 110 NULL
+10 NULL 148 NULL
+48 NULL NULL NULL
+48 NULL NULL NULL
+48 NULL 10 NULL
+48 NULL 48 NULL
+48 NULL 110 NULL
+48 NULL 148 NULL
+100 100 100 100
+110 NULL NULL NULL
+110 NULL NULL NULL
+110 NULL 10 NULL
+110 NULL 48 NULL
+110 NULL 110 NULL
+110 NULL 148 NULL
+148 NULL NULL NULL
+148 NULL NULL NULL
+148 NULL 10 NULL
+148 NULL 48 NULL
+148 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec b/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec
new file mode 100644
index 0000000000000..77f6a8ddd7c28
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec
@@ -0,0 +1,42 @@
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL NULL 48 NULL
+NULL NULL 110 NULL
+NULL NULL 110 NULL
+NULL NULL 148 NULL
+NULL NULL 148 NULL
+NULL 10 NULL 10
+NULL 35 NULL 35
+NULL 110 NULL 110
+NULL 135 NULL 135
+10 NULL NULL NULL
+10 NULL NULL NULL
+10 NULL 10 NULL
+10 NULL 48 NULL
+10 NULL 110 NULL
+10 NULL 148 NULL
+48 NULL NULL NULL
+48 NULL NULL NULL
+48 NULL 10 NULL
+48 NULL 48 NULL
+48 NULL 110 NULL
+48 NULL 148 NULL
+100 100 100 100
+110 NULL NULL NULL
+110 NULL NULL NULL
+110 NULL 10 NULL
+110 NULL 48 NULL
+110 NULL 110 NULL
+110 NULL 148 NULL
+148 NULL NULL NULL
+148 NULL NULL NULL
+148 NULL 10 NULL
+148 NULL 48 NULL
+148 NULL 110 NULL
+148 NULL 148 NULL
+200 200 200 200
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185 b/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185
new file mode 100644
index 0000000000000..1cc70524f9d6d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185
@@ -0,0 +1,11 @@
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
+10 NULL NULL 10
+100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-40-3c868718e4c120cb9a72ab7318c75be3 b/sql/hive/src/test/resources/golden/join_nullsafe-40-3c868718e4c120cb9a72ab7318c75be3
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811 b/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811
new file mode 100644
index 0000000000000..421049d6e509e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811
@@ -0,0 +1,9 @@
+NULL NULL NULL NULL
+NULL NULL 10 NULL
+NULL NULL 48 NULL
+NULL 10 NULL NULL
+NULL 10 10 NULL
+NULL 10 48 NULL
+NULL 35 NULL NULL
+NULL 35 10 NULL
+NULL 35 48 NULL
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-5-1e393de94850e92b3b00536aacc9371f b/sql/hive/src/test/resources/golden/join_nullsafe-5-1e393de94850e92b3b00536aacc9371f
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0 b/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0
new file mode 100644
index 0000000000000..aec3122cae5f9
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0
@@ -0,0 +1,2 @@
+10 NULL NULL 10 10 NULL
+100 100 100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-7-a3ad3cc301d9884898d3e6ab6c792d4c b/sql/hive/src/test/resources/golden/join_nullsafe-7-a3ad3cc301d9884898d3e6ab6c792d4c
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac b/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac
new file mode 100644
index 0000000000000..30db79efa79b4
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac
@@ -0,0 +1,29 @@
+NULL NULL NULL NULL NULL NULL
+NULL NULL NULL NULL NULL 10
+NULL NULL NULL NULL NULL 35
+NULL NULL 10 NULL NULL NULL
+NULL NULL 10 NULL NULL 10
+NULL NULL 10 NULL NULL 35
+NULL NULL 48 NULL NULL NULL
+NULL NULL 48 NULL NULL 10
+NULL NULL 48 NULL NULL 35
+NULL 10 NULL NULL NULL NULL
+NULL 10 NULL NULL NULL 10
+NULL 10 NULL NULL NULL 35
+NULL 10 10 NULL NULL NULL
+NULL 10 10 NULL NULL 10
+NULL 10 10 NULL NULL 35
+NULL 10 48 NULL NULL NULL
+NULL 10 48 NULL NULL 10
+NULL 10 48 NULL NULL 35
+NULL 35 NULL NULL NULL NULL
+NULL 35 NULL NULL NULL 10
+NULL 35 NULL NULL NULL 35
+NULL 35 10 NULL NULL NULL
+NULL 35 10 NULL NULL 10
+NULL 35 10 NULL NULL 35
+NULL 35 48 NULL NULL NULL
+NULL 35 48 NULL NULL 10
+NULL 35 48 NULL NULL 35
+10 NULL NULL 10 10 NULL
+100 100 100 100 100 100
diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-9-88f6f40959b0d2faabd9d4b3cd853809 b/sql/hive/src/test/resources/golden/join_nullsafe-9-88f6f40959b0d2faabd9d4b3cd853809
new file mode 100644
index 0000000000000..e69de29bb2d1d
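The join_nullsafe golden files above record Hive's null-safe join behaviour: the <=> operator matches two NULL keys, which a plain = join condition would silently drop. A minimal sketch of the kind of query these goldens exercise, written against the createQueryTest DSL used later in this patch (the table name myinput1 and its columns are illustrative assumptions, not read from the fixtures):

    // Hedged sketch: a null-safe self join of the sort the goldens above capture.
    // Table and column names are illustrative, not taken from the test fixtures.
    createQueryTest("nullsafe join sketch",
      "SELECT a.key, a.value, b.key, b.value FROM myinput1 a JOIN myinput1 b ON a.key <=> b.value")
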
diff --git a/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876 b/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876
new file mode 100644
index 0000000000000..9b9b6312a269a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876
@@ -0,0 +1 @@
+a = b - Returns TRUE if a equals b and false otherwise
diff --git a/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892 b/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892
new file mode 100644
index 0000000000000..30fdf50f62e4e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892
@@ -0,0 +1,2 @@
+a = b - Returns TRUE if a equals b and false otherwise
+Synonyms: ==
diff --git a/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053 b/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053
new file mode 100644
index 0000000000000..d6b4c860778b7
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053
@@ -0,0 +1 @@
+a == b - Returns TRUE if a equals b and false otherwise
diff --git a/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33 b/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33
new file mode 100644
index 0000000000000..71e55d6d638a6
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33
@@ -0,0 +1,2 @@
+a == b - Returns TRUE if a equals b and false otherwise
+Synonyms: =
diff --git a/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16 b/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16
new file mode 100644
index 0000000000000..015c417bc68f0
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16
@@ -0,0 +1 @@
+false false true true NULL NULL NULL NULL NULL
diff --git a/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7 b/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7
new file mode 100644
index 0000000000000..aa7b4b51edea7
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7
@@ -0,0 +1 @@
+a <=> b - Returns same result with EQUAL(=) operator for non-null operands, but returns TRUE if both are NULL, FALSE if one of the them is NULL
diff --git a/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e b/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e
new file mode 100644
index 0000000000000..aa7b4b51edea7
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e
@@ -0,0 +1 @@
+a <=> b - Returns same result with EQUAL(=) operator for non-null operands, but returns TRUE if both are NULL, FALSE if one of the them is NULL
diff --git a/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27 b/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27
new file mode 100644
index 0000000000000..05292fb23192d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27
@@ -0,0 +1 @@
+false false true true true false false false false
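The udf_equal goldens document the scalar semantics behind those joins: = and == return NULL when either operand is NULL, while <=> always yields a boolean (TRUE when both sides are NULL, FALSE when only one is). A hedged sketch of a query producing a row in that shape (the literal operands are illustrative):

    // Hedged sketch: NULL handling of =, == and <=>. Literals are illustrative.
    createQueryTest("equality operators on NULL sketch",
      "SELECT 1 = NULL, NULL = NULL, 1 <=> NULL, NULL <=> NULL FROM src LIMIT 1")
    // Expected shape of the single result row: NULL, NULL, false, true
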
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 08ef4d9b6bb93..b4dbf2b115799 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -350,12 +350,6 @@ abstract class HiveComparisonTest
val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n")
- println("hive output")
- hive.foreach(println)
-
- println("catalyst printout")
- catalyst.foreach(println)
-
if (recomputeCache) {
logger.warn(s"Clearing cache files for failed test $testCaseName")
hiveCacheFiles.foreach(_.delete())
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 6f36a4f8cb905..a022a1e2dc70e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -52,7 +52,10 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT * FROM src WHERE key Between 1 and 2")
createQueryTest("div",
- "SELECT 1 DIV 2, 1 div 2, 1 dIv 2 FROM src LIMIT 1")
+ "SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1")
+
+ createQueryTest("division",
+ "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1")
test("Query expressed in SQL") {
assert(sql("SELECT 1").collect() === Array(Seq(1)))
@@ -416,10 +419,10 @@ class HiveQuerySuite extends HiveComparisonTest {
hql(s"set $testKey=$testVal")
assert(get(testKey, testVal + "_") == testVal)
- hql("set mapred.reduce.tasks=20")
- assert(get("mapred.reduce.tasks", "0") == "20")
- hql("set mapred.reduce.tasks = 40")
- assert(get("mapred.reduce.tasks", "0") == "40")
+ hql("set some.property=20")
+ assert(get("some.property", "0") == "20")
+ hql("set some.property = 40")
+ assert(get("some.property", "0") == "40")
hql(s"set $testKey=$testVal")
assert(get(testKey, "0") == testVal)
@@ -433,63 +436,61 @@ class HiveQuerySuite extends HiveComparisonTest {
val testKey = "spark.sql.key.usedfortestonly"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
- def collectResults(rdd: SchemaRDD): Set[(String, String)] =
- rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet
clear()
// "set" itself returns all config variables currently specified in SQLConf.
assert(hql("SET").collect().size == 0)
- assertResult(Set(testKey -> testVal)) {
- collectResults(hql(s"SET $testKey=$testVal"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ hql(s"SET $testKey=$testVal").collect().map(_.getString(0))
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Set(testKey -> testVal)) {
- collectResults(hql("SET"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ hql(s"SET $testKey=$testVal").collect().map(_.getString(0))
}
hql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
- collectResults(hql("SET"))
+ assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
+ hql(s"SET").collect().map(_.getString(0))
}
// "set key"
- assertResult(Set(testKey -> testVal)) {
- collectResults(hql(s"SET $testKey"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ hql(s"SET $testKey").collect().map(_.getString(0))
}
- assertResult(Set(nonexistentKey -> "")) {
- collectResults(hql(s"SET $nonexistentKey"))
+ assertResult(Array(s"$nonexistentKey=")) {
+ hql(s"SET $nonexistentKey").collect().map(_.getString(0))
}
// Assert that sql() should have the same effects as hql() by repeating the above using sql().
clear()
assert(sql("SET").collect().size == 0)
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql(s"SET $testKey=$testVal"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ sql(s"SET $testKey=$testVal").collect().map(_.getString(0))
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql("SET"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ sql("SET").collect().map(_.getString(0))
}
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
- collectResults(sql("SET"))
+ assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
+ sql("SET").collect().map(_.getString(0))
}
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql(s"SET $testKey"))
+ assertResult(Array(s"$testKey=$testVal")) {
+ sql(s"SET $testKey").collect().map(_.getString(0))
}
- assertResult(Set(nonexistentKey -> "")) {
- collectResults(sql(s"SET $nonexistentKey"))
+ assertResult(Array(s"$nonexistentKey=")) {
+ sql(s"SET $nonexistentKey").collect().map(_.getString(0))
}
clear()
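
The SET assertions above switch from comparing (key, value) pairs to comparing the single key=value string that each result row now carries in its first column. A minimal sketch of reading SET output under the new format, reusing the test-only key and value already defined in this suite:

    // Hedged sketch: each SET result row is now one "key=value" string in column 0.
    val settings: Array[String] =
      hql(s"SET $testKey=$testVal").collect().map(_.getString(0))
    assert(settings === Array(s"$testKey=$testVal"))
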
diff --git a/streaming/pom.xml b/streaming/pom.xml
index f60697ce745b7..b99f306b8f2cc 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming_2.10
- streaming
+ streaming
jar
Spark Project Streaming
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
index 40da31318942e..1a47089e513c4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -133,17 +133,17 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
val numOldValues = oldRDDs.size
val numNewValues = newRDDs.size
- val mergeValues = (seqOfValues: Seq[Seq[V]]) => {
- if (seqOfValues.size != 1 + numOldValues + numNewValues) {
+ val mergeValues = (arrayOfValues: Array[Iterable[V]]) => {
+ if (arrayOfValues.size != 1 + numOldValues + numNewValues) {
throw new Exception("Unexpected number of sequences of reduced values")
}
// Getting reduced values "old time steps" that will be removed from current window
- val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head)
+ val oldValues = (1 to numOldValues).map(i => arrayOfValues(i)).filter(!_.isEmpty).map(_.head)
// Getting reduced values "new time steps"
val newValues =
- (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
+ (1 to numNewValues).map(i => arrayOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
- if (seqOfValues(0).isEmpty) {
+ if (arrayOfValues(0).isEmpty) {
// If previous window's reduce value does not exist, then at least new values should exist
if (newValues.isEmpty) {
throw new Exception("Neither previous window has value for key, nor new values found. " +
@@ -153,7 +153,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
newValues.reduce(reduceF) // return
} else {
// Get the previous window's reduced value
- var tempValue = seqOfValues(0).head
+ var tempValue = arrayOfValues(0).head
// If old values exists, then inverse reduce then from previous value
if (!oldValues.isEmpty) {
tempValue = invReduceF(tempValue, oldValues.reduce(reduceF))
@@ -166,7 +166,8 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
}
}
- val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues)
+ val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K, Array[Iterable[V]])]]
+ .mapValues(mergeValues)
if (filterFunc.isDefined) {
Some(mergedValuesRDD.filter(filterFunc.get))
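
The ReducedWindowedDStream change above tracks the cogroup API, whose per-key value groups are now typed Array[Iterable[V]] instead of Seq[Seq[V]]. A small, self-contained sketch of consuming that shape (the helper name is illustrative, not part of the DStream internals):

    // Hedged sketch: each array slot holds the values contributed by one co-grouped RDD;
    // taking the head of the first non-empty slot mirrors the
    // .filter(!_.isEmpty).map(_.head) pattern used in the hunk above.
    def firstNonEmpty[V](groups: Array[Iterable[V]]): Option[V] =
      groups.find(_.nonEmpty).map(_.head)
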
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
index 743be58950c09..1868a1ebc7b4a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
@@ -68,13 +68,13 @@ object ActorSupervisorStrategy {
* should be same.
*/
@DeveloperApi
-trait ActorHelper {
+trait ActorHelper extends Logging{
self: Actor => // to ensure that this can be added to Actor classes only
/** Store an iterator of received data as a data block into Spark's memory. */
def store[T](iter: Iterator[T]) {
- println("Storing iterator")
+ logDebug("Storing iterator")
context.parent ! IteratorData(iter)
}
@@ -84,6 +84,7 @@ trait ActorHelper {
* that Spark is configured to use.
*/
def store(bytes: ByteBuffer) {
+ logDebug("Storing Bytes")
context.parent ! ByteBufferData(bytes)
}
@@ -93,7 +94,7 @@ trait ActorHelper {
* being pushed into Spark's memory.
*/
def store[T](item: T) {
- println("Storing item")
+ logDebug("Storing item")
context.parent ! SingleItemData(item)
}
}
@@ -157,15 +158,16 @@ private[streaming] class ActorReceiver[T: ClassTag](
def receive = {
case IteratorData(iterator) =>
- println("received iterator")
+ logDebug("received iterator")
store(iterator.asInstanceOf[Iterator[T]])
case SingleItemData(msg) =>
- println("received single")
+ logDebug("received single")
store(msg.asInstanceOf[T])
n.incrementAndGet
case ByteBufferData(bytes) =>
+ logDebug("received bytes")
store(bytes)
case props: Props =>
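
The ActorReceiver changes route the receiver's tracing through Spark's Logging trait instead of println. A hedged sketch of a user-defined actor built on ActorHelper, which now inherits that debug logging when it stores data (the class name and message handling are illustrative):

    // Hedged sketch: a custom actor that pushes single items into Spark via ActorHelper.
    import akka.actor.Actor
    import org.apache.spark.streaming.receiver.ActorHelper

    class LineForwarder extends Actor with ActorHelper {
      def receive = {
        case line: String => store(line)  // forwarded to the supervisor as SingleItemData
      }
    }
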
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index ce8316bb14891..d934b9cbfc3e8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -110,8 +110,7 @@ private[streaming] class ReceiverSupervisorImpl(
) {
val blockId = optionalBlockId.getOrElse(nextBlockId)
val time = System.currentTimeMillis
- blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]],
- storageLevel, tellMaster = true)
+ blockManager.putArray(blockId, arrayBuffer.toArray[Any], storageLevel, tellMaster = true)
logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
reportPushedBlock(blockId, arrayBuffer.size, optionalMetadata)
}
@@ -124,7 +123,7 @@ private[streaming] class ReceiverSupervisorImpl(
) {
val blockId = optionalBlockId.getOrElse(nextBlockId)
val time = System.currentTimeMillis
- blockManager.put(blockId, iterator, storageLevel, tellMaster = true)
+ blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true)
logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
reportPushedBlock(blockId, -1, optionalMetadata)
}
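
ReceiverSupervisorImpl now calls the renamed BlockManager entry points directly: putArray for buffered records and putIterator for a stream of records. A hedged usage sketch (the block ids, storage level and data are illustrative, and blockManager is assumed to be an executor's BlockManager as in the code above):

    // Hedged sketch of the renamed put* calls exercised by the diff above.
    import org.apache.spark.storage.{StorageLevel, StreamBlockId}

    val arrayBlock = StreamBlockId(streamId = 0, uniqueId = 1L)
    blockManager.putArray(arrayBlock, Array[Any]("a", "b"), StorageLevel.MEMORY_ONLY, tellMaster = true)

    val iterBlock = StreamBlockId(streamId = 0, uniqueId = 2L)
    blockManager.putIterator(iterBlock, Iterator[Any]("c", "d"), StorageLevel.MEMORY_ONLY, tellMaster = true)
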
diff --git a/tools/pom.xml b/tools/pom.xml
index c0ee8faa7a615..97abb6b2b63e0 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -27,7 +27,7 @@
org.apache.spark
spark-tools_2.10
- tools
+ tools
jar
Spark Project Tools
diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml
index 5b13a1f002d6e..51744ece0412d 100644
--- a/yarn/alpha/pom.xml
+++ b/yarn/alpha/pom.xml
@@ -24,7 +24,7 @@
../pom.xml
- yarn-alpha
+ yarn-alpha
org.apache.spark
diff --git a/yarn/pom.xml b/yarn/pom.xml
index efb473aa1b261..3faaf053634d6 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -29,7 +29,7 @@
pom
Spark Project YARN Parent POM
- yarn
+ yarn
diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml
index ceaf9f9d71001..b6c8456d06684 100644
--- a/yarn/stable/pom.xml
+++ b/yarn/stable/pom.xml
@@ -24,7 +24,7 @@
../pom.xml
- yarn-stable
+ yarn-stable
org.apache.spark