From b9df8af62e8d7b263a668dfb6e9668ab4294ea37 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Wed, 8 Oct 2014 23:45:17 -0700 Subject: [PATCH 01/38] [SPARK-2805] Upgrade to akka 2.3.4 Upgrade to akka 2.3.4 Author: Anand Avati Closes #1685 from avati/SPARK-1812-akka-2.3 and squashes the following commits: 57a2315 [Anand Avati] SPARK-1812: streaming - remove tests which depend on akka.actor.IO 2a551d3 [Anand Avati] SPARK-1812: core - upgrade to akka 2.3.4 --- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/AppClient.scala | 2 +- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 +- pom.xml | 2 +- .../spark/streaming/InputStreamsSuite.scala | 71 ------------------- 6 files changed, 6 insertions(+), 77 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fef79ad1001f..cbc0bd178d894 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/pom.xml b/pom.xml index 7756c89b00cad..3b6d4ecbae2c1 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.2.3-shaded-protobuf + 2.3.4-spark 1.7.5 1.2.17 1.0.4 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 952a74fd5f6de..6107fcdc447b6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.streaming import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -144,59 +142,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - // TODO: This test works in IntelliJ but not through SBT - ignore("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - // Had to pass the local value of port to prevent from closing over entire scope - StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -378,22 +323,6 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } -/** This is an actor for testing actor input stream */ -class TestActor(port: Int) extends Actor with ActorHelper { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart(): Unit = { - @deprecated("suppress compile time deprecation warning", "1.0.0") - val unit = IOManager(context.system).connect(new InetSocketAddress(port)) - } - - def receive = { - case IO.Read(socket, bytes) => - store(bytesToString(bytes)) - } -} - /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { From 86b392942daf61fed2ff7490178b128107a0e856 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 9 Oct 2014 00:00:24 -0700 Subject: [PATCH 02/38] [SPARK-3844][UI] Truncate appName in WebUI if it is too long Truncate appName in WebUI if it is too long. Author: Xiangrui Meng Closes #2707 from mengxr/truncate-app-name and squashes the following commits: 87834ce [Xiangrui Meng] move scala import below java c7111dc [Xiangrui Meng] truncate appName in WebUI if it is too long --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index be69060fc3bf8..32e6b15bb0999 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -21,6 +21,7 @@ import java.text.SimpleDateFormat import java.util.{Locale, Date} import scala.xml.Node + import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ @@ -169,6 +170,7 @@ private[spark] object UIUtils extends Logging { refreshInterval: Option[Int] = None): Seq[Node] = { val appName = activeTab.appName + val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • {tab.name} @@ -187,7 +189,9 @@ private[spark] object UIUtils extends Logging { - +
    From 13cab5ba44e2f8d2d2204b3b0d39d7c23a819bdb Mon Sep 17 00:00:00 2001 From: nartz Date: Thu, 9 Oct 2014 00:02:11 -0700 Subject: [PATCH 03/38] add spark.driver.memory to config docs It took me a minute to track this down, so I thought it could be useful to have it in the docs. I'm unsure if 512mb is the default for spark.driver.memory? Also - there could be a better value for the 'description' to differentiate it from spark.executor.memory. Author: nartz Author: Nathan Artz Closes #2410 from nartz/docs/add-spark-driver-memory-to-config-docs and squashes the following commits: a2f6c62 [nartz] Update configuration.md 74521b8 [Nathan Artz] add spark.driver.memory to config docs --- docs/configuration.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 1c33855365170..f311f0d2a6206 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -103,6 +103,14 @@ of the most common options to set are: (e.g. 512m, 2g). + + spark.driver.memory + 512m + + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). + + spark.serializer org.apache.spark.serializer.
    JavaSerializer From 14f222f7f76cc93633aae27a94c0e556e289ec56 Mon Sep 17 00:00:00 2001 From: Qiping Li Date: Thu, 9 Oct 2014 01:36:58 -0700 Subject: [PATCH 04/38] [SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training Currently, the implementation does one unnecessary aggregation step. The aggregation step for level L (to choose splits) gives enough information to set the predictions of any leaf nodes at level L+1. We can use that info and skip the aggregation step for the last level of the tree (which only has leaf nodes). ### Implementation Details Each node now has a `impurity` field and the `predict` is changed from type `Double` to type `Predict`(this can be used to compute predict probability in the future) When compute best splits for each node, we also compute impurity and predict for the child nodes, which is used to constructed newly allocated child nodes. So at level L, we have set impurity and predict for nodes at level L +1. If level L+1 is the last level, then we can avoid aggregation. What's more, calculation of parent impurity in Top nodes for each tree needs to be treated differently because we have to compute impurity and predict for them first. In `binsToBestSplit`, if current node is top node(level == 0), we calculate impurity and predict first. after finding best split, top node's predict and impurity is set to the calculated value. Non-top nodes's impurity and predict are already calculated and don't need to be recalculated again. I have considered to add a initialization step to set top nodes' impurity and predict and then we can treat all nodes in the same way, but this will need a lot of duplication of code(all the code to do seq operation(BinSeqOp) needs to be duplicated), so I choose the current way. CC mengxr manishamde jkbradley, please help me review this, thanks. Author: Qiping Li Closes #2708 from chouqin/avoid-agg and squashes the following commits: 8e269ea [Qiping Li] adjust code and comments eefeef1 [Qiping Li] adjust comments and check child nodes' impurity c41b1b6 [Qiping Li] fix pyspark unit test 7ad7a71 [Qiping Li] fix unit test 822c912 [Qiping Li] add comments and unit test e41d715 [Qiping Li] fix bug in test suite 6cc0333 [Qiping Li] SPARK-3158: Avoid 1 extra aggregation for DecisionTree training --- .../spark/mllib/tree/DecisionTree.scala | 97 +++++++++++------ .../tree/model/InformationGainStats.scala | 9 +- .../apache/spark/mllib/tree/model/Node.scala | 37 +++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 102 ++++++++++++++++-- 4 files changed, 197 insertions(+), 48 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b311d10023894..03eeaa707715b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging { Some(mutableNodeToFeatures.toMap) } + // array of nodes to train indexed by node index in group + val nodes = new Array[Node](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + // Calculate best splits for all nodes in the group timer.start("chooseSplits") @@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging { // find best split for each node val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(aggStats, splits, featuresForNode) + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats, predict)) }.collectAsMap() @@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging { // Extract info for this node. Create children if not leaf. val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) assert(node.id == nodeIndex) - node.predict = predict.predict + node.predict = predict node.isLeaf = isLeaf node.stats = Some(stats) + node.impurity = stats.impurity logDebug("Node = " + node) if (!isLeaf) { node.split = Some(split) - node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) - node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) - nodeQueue.enqueue((treeIndex, node.leftNode.get)) - nodeQueue.enqueue((treeIndex, node.rightNode.get)) + val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), + stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), + stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + } + logDebug("leftChildIndex = " + node.leftNode.get.id + ", impurity = " + stats.leftImpurity) logDebug("rightChildIndex = " + node.rightNode.get.id + @@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging { private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): InformationGainStats = { + metadata: DecisionTreeMetadata, + impurity: Double): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - - val impurity = parentNodeAgg.calculate() - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -649,7 +666,18 @@ object DecisionTree extends Serializable with Logging { return InformationGainStats.invalidInformationGainStats } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + // calculate left and right predict + val leftPredict = calculatePredict(leftImpurityCalculator) + val rightPredict = calculatePredict(rightImpurityCalculator) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, + leftPredict, rightPredict) + } + + private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { + val predict = impurityCalculator.predict + val prob = impurityCalculator.prob(predict) + new Predict(predict, prob) } /** @@ -657,17 +685,17 @@ object DecisionTree extends Serializable with Logging { * Note that this function is called only once for each node. * @param leftImpurityCalculator left node aggregates for a split * @param rightImpurityCalculator right node aggregates for a split - * @return predict value for current node + * @return predict value and impurity for current node */ - private def calculatePredict( + private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): Predict = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) - val predict = parentNodeAgg.predict - val prob = parentNodeAgg.prob(predict) + val predict = calculatePredict(parentNodeAgg) + val impurity = parentNodeAgg.calculate() - new Predict(predict, prob) + (predict, impurity) } /** @@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging { private def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], - featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + featuresForNode: Option[Array[Int]], + node: Node): (Split, InformationGainStats, Predict) = { - // calculate predict only once - var predict: Option[Predict] = None + // calculate predict and impurity if current node is top node + val level = Node.indexToLevel(node.id) + var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { + None + } else { + Some((node.predict, node.impurity)) + } // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = @@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging { Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging { } }.maxBy(_._2.gain) - assert(predict.isDefined, "must calculate predict for each node") - - (bestSplit, bestSplitStats, predict.get) + (bestSplit, bestSplitStats, predictWithImpurity.get._1) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index a89e71e115806..9a50ecb550c38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity + * @param leftPredict left node predict + * @param rightPredict right node predict */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double) extends Serializable { + val rightImpurity: Double, + val leftPredict: Predict, + val rightPredict: Predict) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" @@ -58,5 +62,6 @@ private[tree] object InformationGainStats { * denote that current split doesn't satisfies minimum info gain or * minimum number of instances per node. */ - val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, + new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 56c3e25d9285f..2179da8dbe03e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector * * @param id integer node id, from 1 * @param predict predicted value at the node - * @param isLeaf whether the leaf is a node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf * @param split split to calculate left and right nodes * @param leftNode left child * @param rightNode right child @@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - var predict: Double, + var predict: Predict, + var impurity: Double, var isLeaf: Boolean, var split: Option[Split], var leftNode: Option[Node], @@ -49,7 +51,7 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "split = " + split + ", stats = " + stats + "impurity = " + impurity + "split = " + split + ", stats = " + stats /** * build the left node and right nodes if not leaf @@ -62,6 +64,7 @@ class Node ( logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) + logDebug("impurity = " + impurity) if (!isLeaf) { leftNode = Some(nodes(Node.leftChildIndex(id))) rightNode = Some(nodes(Node.rightChildIndex(id))) @@ -77,7 +80,7 @@ class Node ( */ def predict(features: Vector) : Double = { if (isLeaf) { - predict + predict.predict } else{ if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { @@ -109,7 +112,7 @@ class Node ( } else { Some(rightNode.get.deepCopy()) } - new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) + new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) } /** @@ -154,7 +157,7 @@ class Node ( } val prefix: String = " " * indentFactor if (isLeaf) { - prefix + s"Predict: $predict\n" + prefix + s"Predict: ${predict.predict}\n" } else { prefix + s"If ${splitToString(split.get, left=true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + @@ -170,7 +173,27 @@ private[tree] object Node { /** * Return a node with the given node id (but nothing else set). */ - def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, + false, None, None, None, None) + + /** + * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. + * This is used in `DecisionTree.findBestSplits` to construct child nodes + * after finding the best splits for parent nodes. + * Other fields are set at next level. + * @param nodeIndex integer node id, from 1 + * @param predict predicted value at the node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf + * @return new node instance + */ + def apply( + nodeIndex: Int, + predict: Predict, + impurity: Double, + isLeaf: Boolean): Node = { + new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) + } /** * Return the index of the left child of this node. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a48ed71a1c5fc..98a72b0c4d750 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) assert(stats.impurity > 0.2) } @@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 0.6) + assert(rootNode.predict.predict === 0.6) assert(stats.impurity > 0.2) } @@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 0) + assert(rootNode.predict.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Second level node building with vs. without groups") { @@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats1.impurity === stats2.impurity) assert(stats1.leftImpurity === stats2.leftImpurity) assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict === children2(i).predict) + assert(children1(i).predict.predict === children2(i).predict.predict) } } @@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = input.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) } + + test("Avoid aggregation on the last level") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } } object DecisionTreeSuite { From 1e0aa4deba65aa1241b9a30edb82665eae27242f Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 9 Oct 2014 09:22:32 -0700 Subject: [PATCH 05/38] [Minor] use norm operator after breeze 0.10 upgrade cc mengxr Author: GuoQiang Li Closes #2730 from witgo/SPARK-3856 and squashes the following commits: 2cffce1 [GuoQiang Li] use norm operator after breeze 0.10 upgrade --- .../spark/mllib/feature/NormalizerSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index fb76dccfdf79e..2bf9d9816ae45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite +import breeze.linalg.{norm => brzNorm} + import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -50,10 +52,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(0).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(2).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(3).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(4).toBreeze, 1) ~== 1.0 absTol 1E-5) assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) @@ -77,10 +79,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(0).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(2).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(3).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(4).toBreeze, 2) ~== 1.0 absTol 1E-5) assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) From 73bf3f2e0c03216aa29c25fea2d97205b5977903 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 9 Oct 2014 11:27:21 -0700 Subject: [PATCH 06/38] [SPARK-3741] Make ConnectionManager propagate errors properly and add mo... ...re logs to avoid Executors swallowing errors This PR made the following changes: * Register a callback to `Connection` so that the error will be propagated properly. * Add more logs so that the errors won't be swallowed by Executors. * Use trySuccess/tryFailure because `Promise` doesn't allow to call success/failure more than once. Author: zsxwing Closes #2593 from zsxwing/SPARK-3741 and squashes the following commits: 1d5aed5 [zsxwing] Fix naming 0b8a61c [zsxwing] Merge branch 'master' into SPARK-3741 764aec5 [zsxwing] [SPARK-3741] Make ConnectionManager propagate errors properly and add more logs to avoid Executors swallowing errors --- .../apache/spark/network/nio/Connection.scala | 35 +-- .../spark/network/nio/ConnectionManager.scala | 206 +++++++++++++----- 2 files changed, 172 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index f368209980f93..4f6f5e235811d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,11 +20,14 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList import org.apache.spark._ +import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, @@ -51,7 +54,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, @volatile private var closed = false var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null + val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() @@ -130,20 +133,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, onCloseCallback = callback } - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback + def onException(callback: (Connection, Throwable) => Unit) { + onExceptionCallbacks.add(callback) } def onKeyInterestChange(callback: (Connection, Int) => Unit) { onKeyInterestChangeCallback = callback } - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) + def callOnExceptionCallbacks(e: Throwable) { + onExceptionCallbacks foreach { + callback => + try { + callback(this, e) + } catch { + case NonFatal(e) => { + logWarning("Ignored error in onExceptionCallback", e) + } + } } } @@ -323,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logError("Error connecting to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } } @@ -348,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } true @@ -393,7 +400,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } @@ -420,7 +427,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, case e: Exception => logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() } @@ -577,7 +584,7 @@ private[spark] class ReceivingConnection( } catch { case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 01cd27a907eea..6b00190c5eccc 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,8 @@ import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.Utils +import scala.util.Try +import scala.util.control.NonFatal private[nio] class ConnectionManager( port: Int, @@ -51,14 +53,23 @@ private[nio] class ConnectionManager( class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { + completionHandler: Try[Message] => Unit) { - /** This is non-None if message has been ack'd */ - var ackMessage: Option[Message] = None + def success(ackMessage: Message) { + if (ackMessage == null) { + failure(new NullPointerException) + } + else { + completionHandler(scala.util.Success(ackMessage)) + } + } - def markDone(ackMessage: Option[Message]) { - this.ackMessage = ackMessage - completionHandler(this) + def failWithoutAck() { + completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) + } + + def failure(e: Throwable) { + completionHandler(scala.util.Failure(e)) } } @@ -72,14 +83,32 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.handler.threads.max", 60), conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) + Utils.namedThreadFactory("handle-message-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleMessageExecutor is not handled properly", t) + } + } + + } private val handleReadWriteExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.io.threads.min", 4), conf.getInt("spark.core.connection.io.threads.max", 32), conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) + Utils.namedThreadFactory("handle-read-write-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleReadWriteExecutor is not handled properly", t) + } + } + + } // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : // which should be executed asap @@ -153,17 +182,24 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + val needReregister = register || conn.resetForceReregister() + if (needReregister && conn.changeInterestForWrite()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -187,16 +223,23 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -213,19 +256,25 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { + try { + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need + // not succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } catch { + case NonFatal(e) => { + logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) } } ) } @@ -246,16 +295,16 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { try { - conn.callOnExceptionCallback(e) + conn.callOnExceptionCallbacks(e) } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } try { conn.close() } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } } }) @@ -448,7 +497,7 @@ private[nio] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.markDone(None) + status.failWithoutAck() }) messageStatuses.retain((i, status) => { @@ -477,7 +526,7 @@ private[nio] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.markDone(None) + s.failWithoutAck() } messageStatuses.retain((i, status) => { @@ -492,7 +541,7 @@ private[nio] class ConnectionManager( } } - def handleConnectionError(connection: Connection, e: Exception) { + def handleConnectionError(connection: Connection, e: Throwable) { logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) @@ -510,9 +559,17 @@ private[nio] class ConnectionManager( val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + try { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message, connection) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } catch { + case NonFatal(e) => { + logError("Error when handling messages from " + + connection.getRemoteConnectionManagerId(), e) + connection.callOnExceptionCallbacks(e) + } + } } } handleMessageExecutor.execute(runnable) @@ -651,7 +708,7 @@ private[nio] class ConnectionManager( messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status.markDone(Some(message)) + status.success(message) } case None => { /** @@ -770,6 +827,12 @@ private[nio] class ConnectionManager( val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId, securityManager) + newConnection.onException { + case (conn, e) => { + logError("Exception while sending message.", e) + reportSendingMessageFailure(message.id, e) + } + } logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -782,13 +845,36 @@ private[nio] class ConnectionManager( "connectionid: " + connection.connectionId) if (authEnabled) { - checkSendAuthFirst(connectionManagerId, connection) + try { + checkSendAuthFirst(connectionManagerId, connection) + } catch { + case NonFatal(e) => { + reportSendingMessageFailure(message.id, e) + } + } } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) wakeupSelector() } + private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(messageId) + s match { + case Some(msgStatus) => { + messageStatuses -= messageId + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.failure(e) + } + case None => { + logError("no messageStatus for failed message id: " + messageId) + } + } + } + } + private def wakeupSelector() { selector.wakeup() } @@ -807,9 +893,11 @@ private[nio] class ConnectionManager( override def run(): Unit = { messageStatuses.synchronized { messageStatuses.remove(message.id).foreach ( s => { - promise.failure( - new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec")) + val e = new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } }) } } @@ -817,15 +905,23 @@ private[nio] class ConnectionManager( val status = new MessageStatus(message, connectionManagerId, s => { timeoutTask.cancel() - s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd - promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) - case Some(ackMessage) => + s match { + case scala.util.Failure(e) => + // Indicates a failure where we either never sent or never got ACK'd + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } + case scala.util.Success(ackMessage) => if (ackMessage.hasError) { - promise.failure( - new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + val e = new IOException( + "sendMessageReliably failed with ACK that signalled a remote error") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } } else { - promise.success(ackMessage) + if (!promise.trySuccess(ackMessage)) { + logWarning("Drop ackMessage because promise is completed") + } } } }) From b77a02f41c60d869f48b65e72ed696c05b30bc48 Mon Sep 17 00:00:00 2001 From: Vida Ha Date: Thu, 9 Oct 2014 13:13:31 -0700 Subject: [PATCH 07/38] [SPARK-3752][SQL]: Add tests for different UDF's Author: Vida Ha Closes #2621 from vidaha/vida/SPARK-3752 and squashes the following commits: d7fdbbc [Vida Ha] Add tests for different UDF's --- .../hive/execution/UDFIntegerToString.java | 26 ++++ .../sql/hive/execution/UDFListListInt.java | 51 ++++++++ .../sql/hive/execution/UDFListString.java | 38 ++++++ .../sql/hive/execution/UDFStringString.java | 26 ++++ .../sql/hive/execution/UDFTwoListList.java | 28 +++++ .../sql/hive/execution/HiveUdfSuite.scala | 111 +++++++++++++++--- 6 files changed, 265 insertions(+), 15 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java new file mode 100644 index 0000000000000..6c4f378bc5471 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFIntegerToString extends UDF { + public String evaluate(Integer i) { + return i.toString(); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java new file mode 100644 index 0000000000000..d2d39a8c4dc28 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.List; + +public class UDFListListInt extends UDF { + /** + * + * @param obj + * SQL schema: array> + * Java Type: List> + * @return + */ + public long evaluate(Object obj) { + if (obj == null) { + return 0l; + } + List listList = (List) obj; + long retVal = 0; + for (List aList : listList) { + @SuppressWarnings("unchecked") + List list = (List) aList; + @SuppressWarnings("unchecked") + Integer someInt = (Integer) list.get(1); + try { + retVal += (long) (someInt.intValue()); + } catch (NullPointerException e) { + System.out.println(e); + } + } + return retVal; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java new file mode 100644 index 0000000000000..efd34df293c88 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.List; +import org.apache.commons.lang.StringUtils; + +public class UDFListString extends UDF { + + public String evaluate(Object a) { + if (a == null) { + return null; + } + @SuppressWarnings("unchecked") + List s = (List) a; + + return StringUtils.join(s, ','); + } + + +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java new file mode 100644 index 0000000000000..a369188d471e8 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFStringString extends UDF { + public String evaluate(String s1, String s2) { + return s1 + " " + s2; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java new file mode 100644 index 0000000000000..0165591a7ce78 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFTwoListList extends UDF { + public String evaluate(Object o1, Object o2) { + UDFListListInt udf = new UDFListListInt(); + + return String.format("%s, %s", udf.evaluate(o1), udf.evaluate(o2)); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index e4324e9528f9b..872f28d514efe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -17,33 +17,37 @@ package org.apache.spark.sql.hive.execution -import java.io.{DataOutput, DataInput} +import java.io.{DataInput, DataOutput} import java.util import java.util.Properties -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector} - -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject - -import org.apache.spark.sql.Row +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.io.Writable +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ + +import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) +// Case classes for the custom UDF's. +case class IntegerCaseClass(i: Int) +case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) +case class StringCaseClass(s: String) +case class ListStringCaseClass(l: Seq[String]) + /** * A test suite for Hive custom UDFs. */ -class HiveUdfSuite extends HiveComparisonTest { +class HiveUdfSuite extends QueryTest { + import TestHive._ test("spark sql udf test that returns a struct") { registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -81,7 +85,84 @@ class HiveUdfSuite extends HiveComparisonTest { } test("SPARK-2693 udaf aggregates test") { - assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first) + checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src").collect().toSeq) + } + + test("UDFIntegerToString") { + val testData = TestHive.sparkContext.parallelize( + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) + testData.registerTempTable("integerTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") + checkAnswer( + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), + Seq(Seq("1"), Seq("2"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") + + TestHive.reset() + } + + test("UDFListListInt") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil) + testData.registerTempTable("listListIntTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") + checkAnswer( + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), + Seq(Seq(0), Seq(2), Seq(13))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") + + TestHive.reset() + } + + test("UDFListString") { + val testData = TestHive.sparkContext.parallelize( + ListStringCaseClass(Seq("a", "b", "c")) :: + ListStringCaseClass(Seq("d", "e")) :: Nil) + testData.registerTempTable("listStringTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") + checkAnswer( + sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), + Seq(Seq("a,b,c"), Seq("d,e"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") + + TestHive.reset() + } + + test("UDFStringString") { + val testData = TestHive.sparkContext.parallelize( + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil) + testData.registerTempTable("stringTable") + + sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + checkAnswer( + sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), + Seq(Seq("hello world"), Seq("hello goodbye"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + + TestHive.reset() + } + + test("UDFTwoListList") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: + Nil) + testData.registerTempTable("TwoListTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + checkAnswer( + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), + Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + + TestHive.reset() } } From 752e90f15e0bb82d283f05eff08df874b48caed9 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Thu, 9 Oct 2014 12:59:14 -0700 Subject: [PATCH 08/38] [SPARK-3711][SQL] Optimize where in clause filter queries The In case class is replaced by a InSet class in case all the filters are literals, which uses a hashset instead of Sequence, thereby giving significant performance improvement (earlier the seq was using a worst case linear match (exists method) since expressions were assumed in the filter list) . Maximum improvement should be visible in case small percentage of large data matches the filter list. Author: Yash Datta Closes #2561 from saucam/branch-1.1 and squashes the following commits: 4bf2d19 [Yash Datta] SPARK-3711: 1. Fix code style and import order 2. Fix optimization condition 3. Add tests for null in filter list 4. Add test case that optimization is not triggered in case of attributes in filter list afedbcd [Yash Datta] SPARK-3711: 1. Add test cases for InSet class in ExpressionEvaluationSuite 2. Add class OptimizedInSuite on the lines of ConstantFoldingSuite, for the optimized In clause 0fc902f [Yash Datta] SPARK-3711: UnaryMinus will be handled by constantFolding bd84c67 [Yash Datta] SPARK-3711: Incorporate review comments. Move optimization of In clause to Optimizer.scala by adding a rule. Add appropriate comments 430f5d1 [Yash Datta] SPARK-3711: Optimize the filter list in case of negative values as well bee98aa [Yash Datta] SPARK-3711: Optimize where in clause filter queries --- .../sql/catalyst/expressions/predicates.scala | 19 ++++- .../sql/catalyst/optimizer/Optimizer.scala | 18 ++++- .../ExpressionEvaluationSuite.scala | 21 +++++ .../catalyst/optimizer/OptimizeInSuite.scala | 76 +++++++++++++++++++ 4 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 329af332d0fa1..1e22b2d03c672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.BooleanType - object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = apply(BindReferences.bindReference(expression, inputSchema)) @@ -95,6 +95,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } +/** + * Optimized version of In clause, when all filter values of In clause are + * static. + */ +case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression]) + extends Predicate { + + def children = child + + def nullable = true // TODO: Figure out correct nullability semantics of IN. + override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + + override def eval(input: Row): Any = { + hset.contains(value.eval(input)) + } +} + case class And(left: Expression, right: Expression) extends BinaryPredicate { def symbol = "&&" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 636d0b95583e4..3693b41404fd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -38,7 +39,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] { BooleanSimplification, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions, + OptimizeIn) :: Batch("Filter Pushdown", FixedPoint(100), UnionPushdown, CombineFilters, @@ -273,6 +275,20 @@ object ConstantFolding extends Rule[LogicalPlan] { } } +/** + * Replaces [[In (value, seq[Literal])]] with optimized version[[InSet (value, HashSet[Literal])]] + * which is much faster + */ +object OptimizeIn extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => + val hSet = list.map(e => e.eval(null)) + InSet(v, HashSet() ++ hSet, v +: list) + } + } +} + /** * Simplifies boolean expressions where the answer can be determined without evaluating both sides. * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 63931af4bac3d..692ed78a7292c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp +import scala.collection.immutable.HashSet + import org.scalatest.FunSuite import org.scalatest.Matchers._ import org.scalautils.TripleEqualsSupport.Spread import org.apache.spark.sql.catalyst.types._ + /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -145,6 +148,24 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) } + test("INSET") { + val hS = HashSet[Any]() + 1 + 2 + val nS = HashSet[Any]() + 1 + 2 + null + val one = Literal(1) + val two = Literal(2) + val three = Literal(3) + val nl = Literal(null) + val s = Seq(one, two) + val nullS = Seq(one, two, null) + checkEvaluation(InSet(one, hS, one +: s), true) + checkEvaluation(InSet(two, hS, two +: s), true) + checkEvaluation(InSet(two, nS, two +: nullS), true) + checkEvaluation(InSet(nl, nS, nl +: nullS), true) + checkEvaluation(InSet(three, hS, three +: s), false) + checkEvaluation(InSet(three, nS, three +: nullS), false) + checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) + } + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala new file mode 100644 index 0000000000000..97a78ec971c39 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types._ + +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class OptimizeInSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateAnalysisOperators) :: + Batch("ConstantFolding", Once, + ConstantFolding, + BooleanSimplification, + OptimizeIn) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("OptimizedIn test: In clause optimized to InSet") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2, + UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("OptimizedIn test: In clause not optimized in case filter has attributes") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .analyze + + comparePlans(optimized, correctAnswer) + } +} From 2c8851343a2e4d1d5b3a2b959eaa651a92982a72 Mon Sep 17 00:00:00 2001 From: scwf Date: Thu, 9 Oct 2014 13:22:36 -0700 Subject: [PATCH 09/38] [SPARK-3806][SQL] Minor fix for CliSuite To fix two issues in CliSuite 1 CliSuite throw IndexOutOfBoundsException: Exception in thread "Thread-6" java.lang.IndexOutOfBoundsException: 6 at scala.collection.mutable.ResizableArray$class.apply(ResizableArray.scala:43) at scala.collection.mutable.ArrayBuffer.apply(ArrayBuffer.scala:47) at org.apache.spark.sql.hive.thriftserver.CliSuite.org$apache$spark$sql$hive$thriftserver$CliSuite$$captureOutput$1(CliSuite.scala:67) at org.apache.spark.sql.hive.thriftserver.CliSuite$$anonfun$4.apply(CliSuite.scala:78) at org.apache.spark.sql.hive.thriftserver.CliSuite$$anonfun$4.apply(CliSuite.scala:78) at scala.sys.process.ProcessLogger$$anon$1.out(ProcessLogger.scala:96) at scala.sys.process.BasicIO$$anonfun$processOutFully$1.apply(BasicIO.scala:135) at scala.sys.process.BasicIO$$anonfun$processOutFully$1.apply(BasicIO.scala:135) at scala.sys.process.BasicIO$.readFully$1(BasicIO.scala:175) at scala.sys.process.BasicIO$.processLinesFully(BasicIO.scala:179) at scala.sys.process.BasicIO$$anonfun$processFully$1.apply(BasicIO.scala:164) at scala.sys.process.BasicIO$$anonfun$processFully$1.apply(BasicIO.scala:162) at scala.sys.process.ProcessBuilderImpl$Simple$$anonfun$3.apply$mcV$sp(ProcessBuilderImpl.scala:73) at scala.sys.process.ProcessImpl$Spawn$$anon$1.run(ProcessImpl.scala:22) Actually, it is the Mutil-Threads lead to this problem. 2 Using ```line.startsWith``` instead ```line.contains``` to assert expected answer. This is a tiny bug in CliSuite, for test case "Simple commands", there is a expected answers "5", if we use ```contains``` that means output like "14/10/06 11:```5```4:36 INFO CliDriver: Time taken: 1.078 seconds" or "14/10/06 11:54:36 INFO StatsReportListener: 0% ```5```% 10% 25% 50% 75% 90% 95% 100%" will make the assert true. Author: scwf Closes #2666 from scwf/clisuite and squashes the following commits: 11430db [scwf] fix-clisuite --- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3475c2c9db080..d68dd090b5e6c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -62,9 +62,11 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { def captureOutput(source: String)(line: String) { buffer += s"$source> $line" - if (line.contains(expectedAnswers(next.get()))) { - if (next.incrementAndGet() == expectedAnswers.size) { - foundAllExpectedAnswers.trySuccess(()) + if (next.get() < expectedAnswers.size) { + if (line.startsWith(expectedAnswers(next.get()))) { + if (next.incrementAndGet() == expectedAnswers.size) { + foundAllExpectedAnswers.trySuccess(()) + } } } } From e7edb723d22869f228b838fd242bf8e6fe73ee19 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Thu, 9 Oct 2014 13:46:26 -0700 Subject: [PATCH 10/38] [SPARK-3868][PySpark] Hard to recognize which module is tested from unit-tests.log ./python/run-tests script display messages about which test it is running currently on stdout but not write them on unit-tests.log. It is harder for us to recognize what test programs were executed and which test was failed. Author: cocoatomo Closes #2724 from cocoatomo/issues/3868-display-testing-module-name and squashes the following commits: c63d9fa [cocoatomo] [SPARK-3868][PySpark] Hard to recognize which module is tested from unit-tests.log --- python/run-tests | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/run-tests b/python/run-tests index 63395f72788f9..f6a96841175e8 100755 --- a/python/run-tests +++ b/python/run-tests @@ -25,16 +25,17 @@ FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" cd "$FWDIR/python" FAILED=0 +LOG_FILE=unit-tests.log -rm -f unit-tests.log +rm -f $LOG_FILE # Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL rm -rf metastore warehouse function run_test() { - echo "Running test: $1" + echo "Running test: $1" | tee -a $LOG_FILE - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE FAILED=$((PIPESTATUS[0]||$FAILED)) From ec4d40e48186af18e25517e0474020720645f583 Mon Sep 17 00:00:00 2001 From: Mike Timper Date: Thu, 9 Oct 2014 14:02:27 -0700 Subject: [PATCH 11/38] [SPARK-3853][SQL] JSON Schema support for Timestamp fields In JSONRDD.scala, add 'case TimestampType' in the enforceCorrectType function and a toTimestamp function. Author: Mike Timper Closes #2720 from mtimper/master and squashes the following commits: 9386ab8 [Mike Timper] Fix and tests for SPARK-3853 --- .../main/scala/org/apache/spark/sql/json/JsonRDD.scala | 10 ++++++++++ .../scala/org/apache/spark/sql/json/JsonSuite.scala | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0f27fd13e7379..fbc2965e61e92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.json import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal +import java.sql.Timestamp import com.fasterxml.jackson.databind.ObjectMapper @@ -361,6 +362,14 @@ private[sql] object JsonRDD extends Logging { } } + private def toTimestamp(value: Any): Timestamp = { + value match { + case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) + case value: java.lang.Long => new Timestamp(value) + case value: java.lang.String => Timestamp.valueOf(value) + } + } + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ if (value == null) { null @@ -377,6 +386,7 @@ private[sql] object JsonRDD extends Logging { case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) + case TimestampType => toTimestamp(value) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 685e788207725..3cfcb2b1aa993 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.TestSQLContext._ +import java.sql.Timestamp + class JsonSuite extends QueryTest { import TestJsonData._ TestJsonData @@ -50,6 +52,12 @@ class JsonSuite extends QueryTest { val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) + + checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(new Timestamp(intNumber.toLong), + enforceCorrectType(intNumber.toLong, TimestampType)) + val strDate = "2014-09-30 12:34:56" + checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType)) } test("Get compatible type") { From 1faa1135a3fc0acd89f934f01a4a2edefcb93d33 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 9 Oct 2014 14:50:36 -0700 Subject: [PATCH 12/38] Revert "[SPARK-2805] Upgrade to akka 2.3.4" This reverts commit b9df8af62e8d7b263a668dfb6e9668ab4294ea37. --- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/AppClient.scala | 2 +- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 +- pom.xml | 2 +- .../spark/streaming/InputStreamsSuite.scala | 71 +++++++++++++++++++ 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index f2687ce6b42b4..065ddda50e65e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 98a93d1fcb2a3..32790053a6be8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 63a8ac817b618..6d0d0bbe5ecec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index cbc0bd178d894..1fef79ad1001f 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) + new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) + new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/pom.xml b/pom.xml index 3b6d4ecbae2c1..7756c89b00cad 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.3.4-spark + 2.2.3-shaded-protobuf 1.7.5 1.2.17 1.0.4 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 6107fcdc447b6..952a74fd5f6de 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.streaming import akka.actor.Actor +import akka.actor.IO +import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -142,6 +144,59 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } + // TODO: This test works in IntelliJ but not through SBT + ignore("actor input stream") { + // Start the server + val testServer = new TestServer() + val port = testServer.port + testServer.start() + + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", + // Had to pass the local value of port to prevent from closing over entire scope + StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) + outputStream.register() + ssc.start() + + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = 1 to 9 + val expectedOutput = input.map(x => x.toString) + Thread.sleep(1000) + for (i <- 0 until input.size) { + testServer.send(input(i).toString) + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(1000) + logInfo("Stopping server") + testServer.stop() + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) + } + } + + test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -323,6 +378,22 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } +/** This is an actor for testing actor input stream */ +class TestActor(port: Int) extends Actor with ActorHelper { + + def bytesToString(byteString: ByteString) = byteString.utf8String + + override def preStart(): Unit = { + @deprecated("suppress compile time deprecation warning", "1.0.0") + val unit = IOManager(context.system).connect(new InetSocketAddress(port)) + } + + def receive = { + case IO.Read(socket, bytes) => + store(bytesToString(bytes)) + } +} + /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { From 1c7f0ab302de9f82b1bd6da852d133823bc67c66 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 9 Oct 2014 14:57:27 -0700 Subject: [PATCH 13/38] [SPARK-3339][SQL] Support for skipping json lines that fail to parse This PR aims to provide a way to skip/query corrupt JSON records. To do so, we introduce an internal column to hold corrupt records (the default name is `_corrupt_record`. This name can be changed by setting the value of `spark.sql.columnNameOfCorruptRecord`). When there is a parsing error, we will put the corrupt record in its unparsed format to the internal column. Users can skip/query this column through SQL. * To query those corrupt records ``` -- For Hive parser SELECT `_corrupt_record` FROM jsonTable WHERE `_corrupt_record` IS NOT NULL -- For our SQL parser SELECT _corrupt_record FROM jsonTable WHERE _corrupt_record IS NOT NULL ``` * To skip corrupt records and query regular records ``` -- For Hive parser SELECT field1, field2 FROM jsonTable WHERE `_corrupt_record` IS NULL -- For our SQL parser SELECT field1, field2 FROM jsonTable WHERE _corrupt_record IS NULL ``` Generally, it is not recommended to change the name of the internal column. If the name has to be changed to avoid possible name conflicts, you can use `sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, )` or `sqlContext.sql(SET spark.sql.columnNameOfCorruptRecord=)`. Author: Yin Huai Closes #2680 from yhuai/corruptJsonRecord and squashes the following commits: 4c9828e [Yin Huai] Merge remote-tracking branch 'upstream/master' into corruptJsonRecord 309616a [Yin Huai] Change the default name of corrupt record to "_corrupt_record". b4a3632 [Yin Huai] Merge remote-tracking branch 'upstream/master' into corruptJsonRecord 9375ae9 [Yin Huai] Set the column name of corrupt json record back to the default one after the unit test. ee584c0 [Yin Huai] Provide a way to query corrupt json records as unparsed strings. --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 ++ .../org/apache/spark/sql/SQLContext.scala | 14 +++-- .../spark/sql/api/java/JavaSQLContext.scala | 16 +++-- .../org/apache/spark/sql/json/JsonRDD.scala | 30 ++++++--- .../org/apache/spark/sql/json/JsonSuite.scala | 62 ++++++++++++++++++- .../apache/spark/sql/json/TestJsonData.scala | 9 +++ 6 files changed, 116 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f6f4cf3b80d41..07e6e2eccddf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -35,6 +35,7 @@ private[spark] object SQLConf { val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" + val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -131,6 +132,9 @@ private[sql] trait SQLConf { private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + private[spark] def columnNameOfCorruptRecord: String = + getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record") + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 35561cac3e5e1..014e1e2826724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -195,9 +195,12 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord val appliedSchema = - Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + Option(schema).getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) applySchema(rowRDD, appliedSchema) } @@ -206,8 +209,11 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { - val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord + val appliedSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) applySchema(rowRDD, appliedSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index c006c4330ff66..f8171c3be3207 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -148,8 +148,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * It goes through the entire dataset once to determine the schema. */ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { - val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) - val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord + val appliedScalaSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord)) + val scalaRowRDD = + JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) @@ -162,10 +166,14 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { */ @Experimental def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { + val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord val appliedScalaSchema = Option(asScalaDataType(schema)).getOrElse( - JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] - val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema( + json.rdd, 1.0, columnNameOfCorruptJsonRecord))).asInstanceOf[SStructType] + val scalaRowRDD = JsonRDD.jsonStringToRow( + json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index fbc2965e61e92..61ee960aad9d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal import java.sql.Timestamp +import com.fasterxml.jackson.core.JsonProcessingException import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD @@ -35,16 +36,19 @@ private[sql] object JsonRDD extends Logging { private[sql] def jsonStringToRow( json: RDD[String], - schema: StructType): RDD[Row] = { - parseJson(json).map(parsed => asRow(parsed, schema)) + schema: StructType, + columnNameOfCorruptRecords: String): RDD[Row] = { + parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) } private[sql] def inferSchema( json: RDD[String], - samplingRatio: Double = 1.0): StructType = { + samplingRatio: Double = 1.0, + columnNameOfCorruptRecords: String): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) + val allKeys = + parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) createSchema(allKeys) } @@ -274,7 +278,9 @@ private[sql] object JsonRDD extends Logging { case atom => atom } - private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = { + private def parseJson( + json: RDD[String], + columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = { // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], // ObjectMapper will not return BigDecimal when // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled @@ -289,12 +295,16 @@ private[sql] object JsonRDD extends Logging { // For example: for {"key": 1, "key":2}, we will get "key"->2. val mapper = new ObjectMapper() iter.flatMap { record => - val parsed = mapper.readValue(record, classOf[Object]) match { - case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil - case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] - } + try { + val parsed = mapper.readValue(record, classOf[Object]) match { + case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil + case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + } - parsed + parsed + } catch { + case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil + } } }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 3cfcb2b1aa993..7bb08f1b513ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import java.sql.Timestamp @@ -644,7 +646,65 @@ class JsonSuite extends QueryTest { ("str_a_1", null, null) :: ("str_a_2", null, null) :: (null, "str_b_3", null) :: - ("str_a_4", "str_b_4", "str_c_4") ::Nil + ("str_a_4", "str_b_4", "str_c_4") :: Nil ) } + + test("Corrupt records") { + // Test if we can query corrupt records. + val oldColumnNameOfCorruptRecord = TestSQLContext.columnNameOfCorruptRecord + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + + val jsonSchemaRDD = jsonRDD(corruptRecords) + jsonSchemaRDD.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonSchemaRDD.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + (null, null, null, "{") :: + (null, null, null, "") :: + (null, null, null, """{"a":1, b:2}""") :: + (null, null, null, """{"a":{, b:3}""") :: + ("str_a_4", "str_b_4", "str_c_4", null) :: + (null, null, null, "]") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + ("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Seq("{") :: + Seq("") :: + Seq("""{"a":1, b:2}""") :: + Seq("""{"a":{, b:3}""") :: + Seq("]") :: Nil + ) + + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index fc833b8b54e4c..eaca9f0508a12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -143,4 +143,13 @@ object TestJsonData { """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) + + val corruptRecords = + TestSQLContext.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a":1, b:2}""" :: + """{"a":{, b:3}""" :: + """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: + """]""" :: Nil) } From 0c0e09f567deb775ee378f5385a16884f68b332d Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 9 Oct 2014 14:59:03 -0700 Subject: [PATCH 14/38] [SPARK-3412][SQL]add missing row api chenghao-intel assigned this to me, check PR #2284 for previous discussion Author: Daoyuan Wang Closes #2529 from adrian-wang/rowapi and squashes the following commits: c6594b2 [Daoyuan Wang] using boxed 7b7e6e3 [Daoyuan Wang] update pattern match 7a39456 [Daoyuan Wang] rename file and refresh getAs[T] 4c18c29 [Daoyuan Wang] remove setAs[T] and null judge 1614493 [Daoyuan Wang] add missing row api --- .../sql/catalyst/expressions/Projection.scala | 15 ++++++++++++++ .../spark/sql/catalyst/expressions/Row.scala | 20 ++++++++++--------- ...ificRow.scala => SpecificMutableRow.scala} | 8 ++++++-- 3 files changed, 32 insertions(+), 11 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{SpecificRow.scala => SpecificMutableRow.scala} (97%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index ef1d12531f109..204904ecf04db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -137,6 +137,9 @@ class JoinedRow extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -226,6 +229,9 @@ class JoinedRow2 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -309,6 +315,9 @@ class JoinedRow3 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -392,6 +401,9 @@ class JoinedRow4 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -475,6 +487,9 @@ class JoinedRow5 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index d68a4fabeac77..d00ec39774c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -64,6 +64,7 @@ trait Row extends Seq[Any] with Serializable { def getShort(i: Int): Short def getByte(i: Int): Byte def getString(i: Int): String + def getAs[T](i: Int): T = apply(i).asInstanceOf[T] override def toString() = s"[${this.mkString(",")}]" @@ -118,6 +119,7 @@ object EmptyRow extends Row { def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException def getString(i: Int): String = throw new UnsupportedOperationException + override def getAs[T](i: Int): T = throw new UnsupportedOperationException def copy() = this } @@ -217,19 +219,19 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. */ def this() = this(0) - override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } - override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } - override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } - override def setFloat(ordinal: Int,value: Float): Unit = { values(ordinal) = value } - override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value } - override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value } + override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } + override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } + override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } + override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } + override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } + override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } + override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } - override def setShort(ordinal: Int,value: Short): Unit = { values(ordinal) = value } + override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } - override def update(ordinal: Int,value: Any): Unit = { values(ordinal) = value } + override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } override def copy() = new GenericRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 9cbab3d5d0d0d..570379c533e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -233,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def iterator: Iterator[Any] = values.map(_.boxed).iterator - def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String) = update(ordinal, value) - def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] @@ -306,4 +306,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def getByte(i: Int): Byte = { values(i).asInstanceOf[MutableByte].value } + + override def getAs[T](i: Int): T = { + values(i).boxed.asInstanceOf[T] + } } From bc3b6cb06153d6b05f311dd78459768b6cf6a404 Mon Sep 17 00:00:00 2001 From: Nathan Howell Date: Thu, 9 Oct 2014 15:03:01 -0700 Subject: [PATCH 15/38] [SPARK-3858][SQL] Pass the generator alias into logical plan node The alias parameter is being ignored, which makes it more difficult to specify a qualifier for Generator expressions. Author: Nathan Howell Closes #2721 from NathanHowell/SPARK-3858 and squashes the following commits: 8aa0f43 [Nathan Howell] [SPARK-3858][SQL] Pass the generator alias into logical plan node --- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 2 +- .../test/scala/org/apache/spark/sql/DslQuerySuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 594bf8ffc20e1..948122d42f0e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -360,7 +360,7 @@ class SchemaRDD( join: Boolean = false, outer: Boolean = false, alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan)) + new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) /** * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index d001abb7e1fcc..45e58afe9d9a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -147,6 +147,14 @@ class DslQuerySuite extends QueryTest { (1, 1, 1, 2) :: Nil) } + test("SPARK-3858 generator qualifiers are discarded") { + checkAnswer( + arrayData.as('ad) + .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) + .select("ex.data".attr), + Seq(1, 2, 3, 2, 3, 4).map(Seq(_))) + } + test("average") { checkAnswer( testData2.groupBy()(avg('a)), From ac302052870a650d56f2d3131c27755bb2960ad7 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Thu, 9 Oct 2014 15:14:58 -0700 Subject: [PATCH 16/38] [SPARK-3813][SQL] Support "case when" conditional functions in Spark SQL. "case when" conditional function is already supported in Spark SQL but there is no support in SqlParser. So added parser support to it. Author : ravipesala ravindra.pesalahuawei.com Author: ravipesala Closes #2678 from ravipesala/SPARK-3813 and squashes the following commits: 70c75a7 [ravipesala] Fixed styles 713ea84 [ravipesala] Updated as per admin comments 709684f [ravipesala] Changed parser to support case when function. --- .../org/apache/spark/sql/catalyst/SqlParser.scala | 14 ++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 15 +++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 854b5b461bdc8..4662f585cfe15 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -77,10 +77,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CACHE = Keyword("CACHE") + protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val ELSE = Keyword("ELSE") + protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") @@ -122,11 +125,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") + protected val THEN = Keyword("THEN") protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") + protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") // Use reflection to find the reserved words defined in this class. @@ -333,6 +338,15 @@ class SqlParser extends StandardTokenParsers with PackratParsers { IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case c ~ "," ~ t ~ "," ~ f => If(c,t,f) } | + CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ + (ELSE ~> expression).? <~ END ^^ { + case casePart ~ altPart ~ elsePart => + val altExprs = altPart.flatMap { + case we ~ te => + Seq(casePart.fold(we)(EqualTo(_, we)), te) + } + CaseWhen(altExprs ++ elsePart.toList) + } | (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) } | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b9b196ea5a46a..79de1bb855dbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -680,9 +680,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), ("true", "false") :: Nil) } - + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)} + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + } + + test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { + checkAnswer( + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1) + } + + test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { + checkAnswer( + sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + } } From 4e9b551a0b807f5a2cc6679165c8be4e88a3d077 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 9 Oct 2014 16:08:07 -0700 Subject: [PATCH 17/38] [SPARK-3772] Allow `ipython` to be used by Pyspark workers; IPython support improvements: This pull request addresses a few issues related to PySpark's IPython support: - Fix the remaining uses of the '-u' flag, which IPython doesn't support (see SPARK-3772). - Change PYSPARK_PYTHON_OPTS to PYSPARK_DRIVER_PYTHON_OPTS, so that the old name is reserved in case we ever want to allow the worker Python options to be customized (this variable was introduced in #2554 and hasn't landed in a release yet, so this doesn't break any compatibility). - Introduce a PYSPARK_DRIVER_PYTHON option that allows the driver to use `ipython` while the workers use a different Python version. - Attempt to use Python 2.7 by default if PYSPARK_PYTHON is not specified. - Retain the old semantics for IPYTHON=1 and IPYTHON_OPTS (to avoid breaking existing example programs). There are more details in a block comment in `bin/pyspark`. Author: Josh Rosen Closes #2651 from JoshRosen/SPARK-3772 and squashes the following commits: 7b8eb86 [Josh Rosen] More changes to PySpark python executable configuration: c4f5778 [Josh Rosen] [SPARK-3772] Allow ipython to be used by Pyspark workers; IPython fixes: --- bin/pyspark | 51 ++++++++++++++----- .../api/python/PythonWorkerFactory.scala | 8 ++- .../apache/spark/deploy/PythonRunner.scala | 4 +- docs/programming-guide.md | 8 +-- 4 files changed, 51 insertions(+), 20 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 6655725ef8e8e..96f30a260a09e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -50,22 +50,47 @@ fi . "$FWDIR"/bin/load-spark-env.sh -# Figure out which Python executable to use +# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` +# executable, while the worker would still be launched using PYSPARK_PYTHON. +# +# In Spark 1.2, we removed the documentation of the IPYTHON and IPYTHON_OPTS variables and added +# PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS to allow IPython to be used for the driver. +# Now, users can simply set PYSPARK_DRIVER_PYTHON=ipython to use IPython and set +# PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver +# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython +# and executor Python executables. +# +# For backwards-compatibility, we retain the old IPYTHON and IPYTHON_OPTS variables. + +# Determine the Python executable to use if PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON isn't set: +if hash python2.7 2>/dev/null; then + # Attempt to use Python 2.7, if installed: + DEFAULT_PYTHON="python2.7" +else + DEFAULT_PYTHON="python" +fi + +# Determine the Python executable to use for the driver: +if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then + # If IPython options are specified, assume user wants to run IPython + # (for backwards-compatibility) + PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS" + PYSPARK_DRIVER_PYTHON="ipython" +elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}" +fi + +# Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON="ipython" + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then + echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 + exit 1 else - PYSPARK_PYTHON="python" + PYSPARK_PYTHON="$DEFAULT_PYTHON" fi fi export PYSPARK_PYTHON -if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS" -fi - # Add the PySpark classes to the Python path: export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" @@ -93,9 +118,9 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_PYTHON" -m doctest $1 + exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 else - exec "$PYSPARK_PYTHON" $1 + exec "$PYSPARK_DRIVER_PYTHON" $1 fi exit fi @@ -111,5 +136,5 @@ if [[ "$1" =~ \.py$ ]]; then else # PySpark shell requires special handling downstream export PYSPARK_SHELL=1 - exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS + exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS fi diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 71bdf0fe1b917..e314408c067e9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -108,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") val worker = pb.start() // Redirect worker stdout and stderr @@ -149,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 79b4d7ea41a33..af94b05ce3847 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -34,7 +34,8 @@ object PythonRunner { val pythonFile = args(0) val pyFiles = args(1) val otherArgs = args.slice(2, args.length) - val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python")) // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) @@ -57,6 +58,7 @@ object PythonRunner { val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) val env = builder.environment() env.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 8e8cc1dd983f8..18420afb27e3c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -211,17 +211,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes, It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To -use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`: +use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running `bin/pyspark`: {% highlight bash %} -$ PYSPARK_PYTHON=ipython ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch +You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. For example, to launch the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark {% endhighlight %} From 2837bf8548db7e9d43f6eefedf5a73feb22daedb Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 9 Oct 2014 17:54:02 -0700 Subject: [PATCH 18/38] [SPARK-3798][SQL] Store the output of a generator in a val This prevents it from changing during serialization, leading to corrupted results. Author: Michael Armbrust Closes #2656 from marmbrus/generateBug and squashes the following commits: efa32eb [Michael Armbrust] Store the output of a generator in a val. This prevents it from changing during serialization. --- .../main/scala/org/apache/spark/sql/execution/Generate.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index c386fd121c5de..38877c28de3a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -39,7 +39,8 @@ case class Generate( child: SparkPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = { + // This must be a val since the generator output expr ids are not preserved by serialization. + protected val generatorOutput: Seq[Attribute] = { if (join && outer) { generator.output.map(_.withNullability(true)) } else { @@ -62,7 +63,7 @@ case class Generate( newProjection(child.output ++ nullValues, child.output) val joinProjection = - newProjection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) val joinedRow = new JoinedRow iter.flatMap {row => From 363baacaded56047bcc63276d729ab911e0336cf Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 9 Oct 2014 18:21:59 -0700 Subject: [PATCH 19/38] SPARK-3811 [CORE] More robust / standard Utils.deleteRecursively, Utils.createTempDir I noticed a few issues with how temp directories are created and deleted: *Minor* * Guava's `Files.createTempDir()` plus `File.deleteOnExit()` is used in many tests to make a temp dir, but `Utils.createTempDir()` seems to be the standard Spark mechanism * Call to `File.deleteOnExit()` could be pushed into `Utils.createTempDir()` as well, along with this replacement * _I messed up the message in an exception in `Utils` in SPARK-3794; fixed here_ *Bit Less Minor* * `Utils.deleteRecursively()` fails immediately if any `IOException` occurs, instead of trying to delete any remaining files and subdirectories. I've observed this leave temp dirs around. I suggest changing it to continue in the face of an exception and throw one of the possibly several exceptions that occur at the end. * `Utils.createTempDir()` will add a JVM shutdown hook every time the method is called. Even if the subdir is the parent of another parent dir, since this check is inside the hook. However `Utils` manages a set of all dirs to delete on shutdown already, called `shutdownDeletePaths`. A single hook can be registered to delete all of these on exit. This is how Tachyon temp paths are cleaned up in `TachyonBlockManager`. I noticed a few other things that might be changed but wanted to ask first: * Shouldn't the set of dirs to delete be `File`, not just `String` paths? * `Utils` manages the set of `TachyonFile` that have been registered for deletion, but the shutdown hook is managed in `TachyonBlockManager`. Should this logic not live together, and not in `Utils`? it's more specific to Tachyon, and looks a slight bit odd to import in such a generic place. Author: Sean Owen Closes #2670 from srowen/SPARK-3811 and squashes the following commits: 071ae60 [Sean Owen] Update per @vanzin's review da0146d [Sean Owen] Make Utils.deleteRecursively try to delete all paths even when an exception occurs; use one shutdown hook instead of one per method call to delete temp dirs 3a0faa4 [Sean Owen] Standardize on Utils.createTempDir instead of Files.createTempDir --- .../scala/org/apache/spark/TestUtils.scala | 5 +- .../scala/org/apache/spark/util/Utils.scala | 55 +++++++++++++------ .../org/apache/spark/FileServerSuite.scala | 4 +- .../scala/org/apache/spark/FileSuite.scala | 4 +- .../spark/deploy/SparkSubmitSuite.scala | 3 +- .../WholeTextFileRecordReaderSuite.scala | 6 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 21 ++++--- .../scheduler/EventLoggingListenerSuite.scala | 4 +- .../spark/scheduler/ReplayListenerSuite.scala | 4 +- .../spark/storage/DiskBlockManagerSuite.scala | 17 +----- .../apache/spark/util/FileLoggerSuite.scala | 3 +- .../org/apache/spark/util/UtilsSuite.scala | 28 +++++++++- .../spark/mllib/util/MLUtilsSuite.scala | 9 ++- .../spark/repl/ExecutorClassLoaderSuite.scala | 8 +-- .../org/apache/spark/repl/ReplSuite.scala | 4 +- .../spark/streaming/CheckpointSuite.scala | 3 +- .../spark/streaming/InputStreamsSuite.scala | 3 +- .../spark/streaming/MasterFailureTest.scala | 3 +- .../spark/streaming/TestSuiteBase.scala | 5 +- .../spark/deploy/yarn/ClientBaseSuite.scala | 5 +- 20 files changed, 102 insertions(+), 92 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 8ca731038e528..e72826dc25f41 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -26,6 +26,8 @@ import scala.collection.JavaConversions._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import com.google.common.io.Files +import org.apache.spark.util.Utils + /** * Utilities for tests. Included in main codebase since it's used by multiple * projects. @@ -42,8 +44,7 @@ private[spark] object TestUtils { * in order to avoid interference between tests. */ def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) createJar(files, jarFile) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3d307b3c16d3e..07477dd460a4b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -168,6 +168,20 @@ private[spark] object Utils extends Logging { private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + // Add a shutdown hook to delete the temp dirs when the JVM exits + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") { + override def run(): Unit = Utils.logUncaughtExceptions { + logDebug("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + }) + // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { val absolutePath = file.getAbsolutePath() @@ -252,14 +266,6 @@ private[spark] object Utils extends Logging { } registerShutdownDeleteDir(dir) - - // Add a shutdown hook to delete the temp dir when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { - override def run() { - // Attempt to delete if some patch which is parent of this is not already registered. - if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) - } - }) dir } @@ -666,15 +672,30 @@ private[spark] object Utils extends Logging { */ def deleteRecursively(file: File) { if (file != null) { - if (file.isDirectory() && !isSymlink(file)) { - for (child <- listFilesSafely(file)) { - deleteRecursively(child) + try { + if (file.isDirectory && !isSymlink(file)) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(file.getAbsolutePath) + } } - } - if (!file.delete()) { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } } } } @@ -713,7 +734,7 @@ private[spark] object Utils extends Logging { */ def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { if (!dir.isDirectory) { - throw new IllegalArgumentException("$dir is not a directory!") + throw new IllegalArgumentException(s"$dir is not a directory!") } val filesAndDirs = dir.listFiles() val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000) diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 7e18f45de7b5b..a8867020e457d 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io._ import java.util.jar.{JarEntry, JarOutputStream} -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.SparkContext._ @@ -41,8 +40,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { override def beforeAll() { super.beforeAll() - tmpDir = Files.createTempDir() - tmpDir.deleteOnExit() + tmpDir = Utils.createTempDir() val testTempDir = new File(tmpDir, "test") testTempDir.mkdir() diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 4a53d25012ad9..a2b74c4419d46 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, FileWriter} import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} @@ -39,8 +38,7 @@ class FileSuite extends FunSuite with LocalSparkContext { override def beforeEach() { super.beforeEach() - tempDir = Files.createTempDir() - tempDir.deleteOnExit() + tempDir = Utils.createTempDir() } override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4cba90e8f2afe..1cdf50d5c08c7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite import org.scalatest.Matchers -import com.google.common.io.Files class SparkSubmitSuite extends FunSuite with Matchers { def beforeAll() { @@ -332,7 +331,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { } def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf)) diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index d5ebfb3f3fae1..12d1c7b2faba6 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -23,8 +23,6 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq -import com.google.common.io.Files - import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite @@ -66,9 +64,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { * 3) Does the contents be the same. */ test("Correctness of WholeTextFileRecordReader.") { - - val dir = Files.createTempDir() - dir.deleteOnExit() + val dir = Utils.createTempDir() println(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 75b01191901b8..3620e251cc139 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -24,13 +24,14 @@ import org.apache.hadoop.util.Progressable import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.Random -import com.google.common.io.Files import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.spark.{Partitioner, SharedSparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.util.Utils + import org.scalatest.FunSuite class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { @@ -381,14 +382,16 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("zero-partition RDD") { - val emptyDir = Files.createTempDir() - emptyDir.deleteOnExit() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - emptyDir.delete() + val emptyDir = Utils.createTempDir() + try { + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.isEmpty) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } finally { + Utils.deleteRecursively(emptyDir) + } } test("keys and values") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 3efa85431876b..abc300fcffaf9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import scala.collection.mutable import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -51,8 +50,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { private var logDirPath: Path = _ before { - testDir = Files.createTempDir() - testDir.deleteOnExit() + testDir = Utils.createTempDir() logDirPath = Utils.getFilePath(testDir, "spark-events") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 48114feee6233..e05f373392d4a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} -import com.google.common.io.Files import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -39,8 +38,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { private var testDir: File = _ before { - testDir = Files.createTempDir() - testDir.deleteOnExit() + testDir = Utils.createTempDir() } after { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index e4522e00a622d..bc5c74c126b74 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,22 +19,13 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.shuffle.hash.HashShuffleManager - -import scala.collection.mutable import scala.language.reflectiveCalls -import akka.actor.Props -import com.google.common.io.Files import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -48,10 +39,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before override def beforeAll() { super.beforeAll() - rootDir0 = Files.createTempDir() - rootDir0.deleteOnExit() - rootDir1 = Files.createTempDir() - rootDir1.deleteOnExit() + rootDir0 = Utils.createTempDir() + rootDir1 = Utils.createTempDir() rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath } diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index c3dd156b40514..dc2a05631d83d 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, IOException} import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfter, FunSuite} @@ -44,7 +43,7 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { private var logDirPathString: String = _ before { - testDir = Files.createTempDir() + testDir = Utils.createTempDir() logDirPath = Utils.getFilePath(testDir, "test-file-logger") logDirPathString = logDirPath.toString } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index e63d9d085e385..0344da60dae66 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -112,7 +112,7 @@ class UtilsSuite extends FunSuite { } test("reading offset bytes of a file") { - val tmpDir2 = Files.createTempDir() + val tmpDir2 = Utils.createTempDir() tmpDir2.deleteOnExit() val f1Path = tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) @@ -141,7 +141,7 @@ class UtilsSuite extends FunSuite { } test("reading offset bytes across multiple files") { - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() tmpDir.deleteOnExit() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) Files.write("0123456789", files(0), Charsets.UTF_8) @@ -308,4 +308,28 @@ class UtilsSuite extends FunSuite { } } + test("deleteRecursively") { + val tempDir1 = Utils.createTempDir() + assert(tempDir1.exists()) + Utils.deleteRecursively(tempDir1) + assert(!tempDir1.exists()) + + val tempDir2 = Utils.createTempDir() + val tempFile1 = new File(tempDir2, "foo.txt") + Files.touch(tempFile1) + assert(tempFile1.exists()) + Utils.deleteRecursively(tempFile1) + assert(!tempFile1.exists()) + + val tempDir3 = new File(tempDir2, "subdir") + assert(tempDir3.mkdir()) + val tempFile2 = new File(tempDir3, "bar.txt") + Files.touch(tempFile2) + assert(tempFile2.exists()) + Utils.deleteRecursively(tempDir2) + assert(!tempDir2.exists()) + assert(!tempDir3.exists()) + assert(!tempFile2.exists()) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 8ef2bb1bf6a78..0dbe766b4d917 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -67,8 +67,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString @@ -100,7 +99,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), LabeledPoint(0.0, Vectors.dense(1.01, 2.02, 3.03)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "output") MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString) val lines = outputDir.listFiles() @@ -166,7 +165,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Vectors.sparse(2, Array(1), Array(-1.0)), Vectors.dense(0.0, 1.0) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "vectors") val path = outputDir.toURI.toString vectors.saveAsTextFile(path) @@ -181,7 +180,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))), LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "points") val path = outputDir.toURI.toString points.saveAsTextFile(path) diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 3e2ee7541f40d..6a79e76a34db8 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -23,8 +23,6 @@ import java.net.{URL, URLClassLoader} import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite -import com.google.common.io.Files - import org.apache.spark.{SparkConf, TestUtils} import org.apache.spark.util.Utils @@ -39,10 +37,8 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def beforeAll() { super.beforeAll() - tempDir1 = Files.createTempDir() - tempDir1.deleteOnExit() - tempDir2 = Files.createTempDir() - tempDir2.deleteOnExit() + tempDir1 = Utils.createTempDir() + tempDir2 = Utils.createTempDir() url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index c8763eb277052..91c9c52c3c98a 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,7 +22,6 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.SparkContext import org.apache.commons.lang3.StringEscapeUtils @@ -190,8 +189,7 @@ class ReplSuite extends FunSuite { } test("interacting with files") { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val out = new FileWriter(tempDir + "/input") out.write("Hello world!\n") out.write("What's up?\n") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8511390cb1ad5..e5592e52b0d2d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -231,8 +231,7 @@ class CheckpointSuite extends TestSuiteBase { // failure, are re-processed or not. test("recovery with file input stream") { // Set up the streaming context and input streams - val testDir = Files.createTempDir() - testDir.deleteOnExit() + val testDir = Utils.createTempDir() var ssc = new StreamingContext(master, framework, Seconds(1)) ssc.checkpoint(checkpointDir) val fileStream = ssc.textFileStream(testDir.toString) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 952a74fd5f6de..a44a45a3e9bd6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -98,8 +98,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") // Set up the streaming context and input streams - val testDir = Files.createTempDir() - testDir.deleteOnExit() + val testDir = Utils.createTempDir() val ssc = new StreamingContext(conf, batchDuration) val fileStream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index c53c01706083a..5dbb7232009eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -352,8 +352,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) extends Thread with Logging { override def run() { - val localTestDir = Files.createTempDir() - localTestDir.deleteOnExit() + val localTestDir = Utils.createTempDir() var fs = testDir.getFileSystem(new Configuration()) val maxTries = 3 try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 759baacaa4308..9327ff4822699 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer import scala.reflect.ClassTag import org.scalatest.{BeforeAndAfter, FunSuite} -import com.google.common.io.Files import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.util.ManualClock import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -120,9 +120,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Directory where the checkpoint data will be saved lazy val checkpointDir = { - val dir = Files.createTempDir() + val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") - dir.deleteOnExit() dir.toString } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 9bd916100dd2c..17b79ae1d82c4 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -20,13 +20,10 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI -import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ @@ -117,7 +114,7 @@ class ClientBaseSuite extends FunSuite with Matchers { doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), any(classOf[Path]), anyShort(), anyBoolean()) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { client.prepareLocalResources(tempDir.getAbsolutePath()) sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER)) From edf02da389f75df5a42465d41f035d6b65599848 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 9 Oct 2014 18:25:06 -0700 Subject: [PATCH 20/38] [SPARK-3654][SQL] Unifies SQL and HiveQL parsers This PR is a follow up of #2590, and tries to introduce a top level SQL parser entry point for all SQL dialects supported by Spark SQL. A top level parser `SparkSQLParser` is introduced to handle the syntaxes that all SQL dialects should recognize (e.g. `CACHE TABLE`, `UNCACHE TABLE` and `SET`, etc.). For all the syntaxes this parser doesn't recognize directly, it fallbacks to a specified function that tries to parse arbitrary input to a `LogicalPlan`. This function is typically another parser combinator like `SqlParser`. DDL syntaxes introduced in #2475 can be moved to here. The `ExtendedHiveQlParser` now only handle Hive specific extensions. Also took the chance to refactor/reformat `SqlParser` for better readability. Author: Cheng Lian Closes #2698 from liancheng/gen-sql-parser and squashes the following commits: ceada76 [Cheng Lian] Minor styling fixes 9738934 [Cheng Lian] Minor refactoring, removes optional trailing ";" in the parser bb2ab12 [Cheng Lian] SET property value can be empty string ce8860b [Cheng Lian] Passes test suites e86968e [Cheng Lian] Removes debugging code 8bcace5 [Cheng Lian] Replaces digit.+ to rep1(digit) (Scala style checking doesn't like it) d15d54f [Cheng Lian] Unifies SQL and HiveQL parsers --- .../spark/sql/catalyst/SparkSQLParser.scala | 186 ++++++++ .../apache/spark/sql/catalyst/SqlParser.scala | 426 +++++++----------- .../sql/catalyst/plans/logical/commands.scala | 15 +- .../org/apache/spark/sql/SQLContext.scala | 9 +- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../apache/spark/sql/execution/commands.scala | 34 +- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../server/SparkSQLOperationManager.scala | 2 +- .../spark/sql/hive/ExtendedHiveQlParser.scala | 110 +---- .../org/apache/spark/sql/hive/HiveQl.scala | 15 +- .../spark/sql/hive/HiveStrategies.scala | 8 +- 12 files changed, 414 insertions(+), 401 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala new file mode 100644 index 0000000000000..04467342e6ab5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.{PackratParsers, RegexParsers} +import scala.util.parsing.input.CharArrayReader.EofCh + +import org.apache.spark.sql.catalyst.plans.logical._ + +private[sql] abstract class AbstractSparkSQLParser + extends StandardTokenParsers with PackratParsers { + + def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } + + protected case class Keyword(str: String) + + protected def start: Parser[LogicalPlan] + + // Returns the whole input string + protected lazy val wholeInput: Parser[String] = new Parser[String] { + def apply(in: Input): ParseResult[String] = + Success(in.source.toString, in.drop(in.source.length())) + } + + // Returns the rest of the input string that are not parsed yet + protected lazy val restInput: Parser[String] = new Parser[String] { + def apply(in: Input): ParseResult[String] = + Success( + in.source.subSequence(in.offset, in.source.length()).toString, + in.drop(in.source.length())) + } +} + +class SqlLexical(val keywords: Seq[String]) extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + + reserved ++= keywords.flatMap(w => allCaseVersions(w)) + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]", "." + ) + + override lazy val token: Parser[Token] = + ( identChar ~ (identChar | digit).* ^^ + { case first ~ rest => processIdent((first :: rest).mkString) } + | rep1(digit) ~ ('.' ~> digit.*).? ^^ { + case i ~ None => NumericLit(i.mkString) + case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) + } + | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ + { case chars => StringLit(chars mkString "") } + | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ + { case chars => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('_') + + override def whitespace: Parser[Any] = + ( whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ chrExcept(EofCh, '\n').* + | '#' ~ chrExcept(EofCh, '\n').* + | '-' ~ '-' ~ chrExcept(EofCh, '\n').* + | '/' ~ '*' ~ failure("unclosed comment") + ).* + + /** Generate all variations of upper and lower case of a given string */ + def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { + if (s == "") { + Stream(prefix) + } else { + allCaseVersions(s.tail, prefix + s.head.toLower) ++ + allCaseVersions(s.tail, prefix + s.head.toUpper) + } + } +} + +/** + * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL + * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. + * + * @param fallback A function that parses an input string to a logical plan + */ +private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { + + // A parser for the key-value part of the "SET [key = [value ]]" syntax + private object SetCommandParser extends RegexParsers { + private val key: Parser[String] = "(?m)[^=]+".r + + private val value: Parser[String] = "(?m).*$".r + + private val pair: Parser[LogicalPlan] = + (key ~ ("=".r ~> value).?).? ^^ { + case None => SetCommand(None) + case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) + } + + def apply(input: String): LogicalPlan = parseAll(pair, input) match { + case Success(plan, _) => plan + case x => sys.error(x.toString) + } + } + + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val TABLE = Keyword("TABLE") + protected val SOURCE = Keyword("SOURCE") + protected val UNCACHE = Keyword("UNCACHE") + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + private val reservedWords: Seq[String] = + this + .getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + override protected lazy val start: Parser[LogicalPlan] = + cache | uncache | set | shell | source | others + + private lazy val cache: Parser[LogicalPlan] = + CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined) + } + + private lazy val uncache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident ^^ { + case tableName => UncacheTableCommand(tableName) + } + + private lazy val set: Parser[LogicalPlan] = + SET ~> restInput ^^ { + case input => SetCommandParser(input) + } + + private lazy val shell: Parser[LogicalPlan] = + "!" ~> restInput ^^ { + case input => ShellCommand(input.trim) + } + + private lazy val source: Parser[LogicalPlan] = + SOURCE ~> restInput ^^ { + case input => SourceCommand(input.trim) + } + + private lazy val others: Parser[LogicalPlan] = + wholeInput ^^ { + case input => fallback(input) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 4662f585cfe15..b4d606d37e732 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -18,10 +18,6 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers -import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -39,31 +35,7 @@ import org.apache.spark.sql.catalyst.types._ * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends StandardTokenParsers with PackratParsers { - - def apply(input: String): LogicalPlan = { - // Special-case out set commands since the value fields can be - // complex to handle without RegexParsers. Also this approach - // is clearer for the several possible cases of set commands. - if (input.trim.toLowerCase.startsWith("set")) { - input.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) - } - } - } - - protected case class Keyword(str: String) - +class SqlParser extends AbstractSparkSQLParser { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) @@ -100,7 +72,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") protected val LAST = Keyword("LAST") - protected val LAZY = Keyword("LAZY") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") @@ -128,7 +99,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val THEN = Keyword("THEN") protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") - protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") @@ -136,7 +106,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { // Use reflection to find the reserved words defined in this class. protected val reservedWords = - this.getClass + this + .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) .map(_.invoke(this).asInstanceOf[Keyword].str) @@ -150,86 +121,68 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } } - protected lazy val query: Parser[LogicalPlan] = ( - select * ( - UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | - INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } | - EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} | - UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } + protected lazy val start: Parser[LogicalPlan] = + ( select * + ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } + | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } + | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} + | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert | cache | unCache - ) + | insert + ) protected lazy val select: Parser[LogicalPlan] = - SELECT ~> opt(DISTINCT) ~ projections ~ - opt(from) ~ opt(filter) ~ - opt(grouping) ~ - opt(having) ~ - opt(orderBy) ~ - opt(limit) <~ opt(";") ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) - val withFilter = f.map(f => Filter(f, base)).getOrElse(base) - val withProjection = - g.map {g => - Aggregate(g, assignAliases(p), withFilter) - }.getOrElse(Project(assignAliases(p), withFilter)) - val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) - val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) - val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving) - val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder) - withLimit - } + SELECT ~> DISTINCT.? ~ + repsep(projection, ",") ~ + (FROM ~> relations).? ~ + (WHERE ~> expression).? ~ + (GROUP ~ BY ~> rep1sep(expression, ",")).? ~ + (HAVING ~> expression).? ~ + (ORDER ~ BY ~> ordering).? ~ + (LIMIT ~> expression).? ^^ { + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + val base = r.getOrElse(NoRelation) + val withFilter = f.map(f => Filter(f, base)).getOrElse(base) + val withProjection = g + .map(Aggregate(_, assignAliases(p), withFilter)) + .getOrElse(Project(assignAliases(p), withFilter)) + val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) + val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) + val withOrder = o.map(Sort(_, withHaving)).getOrElse(withHaving) + val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder) + withLimit + } protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> opt(OVERWRITE) ~ inTo ~ select <~ opt(";") ^^ { - case o ~ r ~ s => - val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" - InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) - } - - protected lazy val cache: Parser[LogicalPlan] = - CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> select) <~ opt(";") ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan, isLazy.isDefined) - } - - protected lazy val unCache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ { - case tableName => UncacheTableCommand(tableName) + INSERT ~> OVERWRITE.? ~ (INTO ~> relation) ~ select ^^ { + case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o.isDefined) } - protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") - protected lazy val projection: Parser[Expression] = - expression ~ (opt(AS) ~> opt(ident)) ^^ { - case e ~ None => e - case e ~ Some(a) => Alias(e, a)() + expression ~ (AS.? ~> ident.?) ^^ { + case e ~ a => a.fold(e)(Alias(e, _)()) } - protected lazy val from: Parser[LogicalPlan] = FROM ~> relations - - protected lazy val inTo: Parser[LogicalPlan] = INTO ~> relation - // Based very loosely on the MySQL Grammar. // http://dev.mysql.com/doc/refman/5.0/en/join.html protected lazy val relations: Parser[LogicalPlan] = - relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } | - relation + ( relation ~ ("," ~> relation) ^^ { case r1 ~ r2 => Join(r1, r2, Inner, None) } + | relation + ) protected lazy val relation: Parser[LogicalPlan] = - joinedRelation | - relationFactor + joinedRelation | relationFactor protected lazy val relationFactor: Parser[LogicalPlan] = - ident ~ (opt(AS) ~> opt(ident)) ^^ { - case tableName ~ alias => UnresolvedRelation(None, tableName, alias) - } | - "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } + ( ident ~ (opt(AS) ~> opt(ident)) ^^ { + case tableName ~ alias => UnresolvedRelation(None, tableName, alias) + } + | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } + ) protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { - case r1 ~ jt ~ _ ~ r2 ~ cond => + relationFactor ~ joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.? ^^ { + case r1 ~ jt ~ r2 ~ cond => Join(r1, r2, joinType = jt.getOrElse(Inner), cond) } @@ -237,160 +190,145 @@ class SqlParser extends StandardTokenParsers with PackratParsers { ON ~> expression protected lazy val joinType: Parser[JoinType] = - INNER ^^^ Inner | - LEFT ~ SEMI ^^^ LeftSemi | - LEFT ~ opt(OUTER) ^^^ LeftOuter | - RIGHT ~ opt(OUTER) ^^^ RightOuter | - FULL ~ opt(OUTER) ^^^ FullOuter - - protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } - - protected lazy val orderBy: Parser[Seq[SortOrder]] = - ORDER ~> BY ~> ordering + ( INNER ^^^ Inner + | LEFT ~ SEMI ^^^ LeftSemi + | LEFT ~ OUTER.? ^^^ LeftOuter + | RIGHT ~ OUTER.? ^^^ RightOuter + | FULL ~ OUTER.? ^^^ FullOuter + ) protected lazy val ordering: Parser[Seq[SortOrder]] = - rep1sep(singleOrder, ",") | - rep1sep(expression, ",") ~ opt(direction) ^^ { - case exps ~ None => exps.map(SortOrder(_, Ascending)) - case exps ~ Some(d) => exps.map(SortOrder(_, d)) - } + ( rep1sep(singleOrder, ",") + | rep1sep(expression, ",") ~ direction.? ^^ { + case exps ~ d => exps.map(SortOrder(_, d.getOrElse(Ascending))) + } + ) protected lazy val singleOrder: Parser[SortOrder] = - expression ~ direction ^^ { case e ~ o => SortOrder(e,o) } + expression ~ direction ^^ { case e ~ o => SortOrder(e, o) } protected lazy val direction: Parser[SortDirection] = - ASC ^^^ Ascending | - DESC ^^^ Descending - - protected lazy val grouping: Parser[Seq[Expression]] = - GROUP ~> BY ~> rep1sep(expression, ",") - - protected lazy val having: Parser[Expression] = - HAVING ~> expression - - protected lazy val limit: Parser[Expression] = - LIMIT ~> expression + ( ASC ^^^ Ascending + | DESC ^^^ Descending + ) - protected lazy val expression: Parser[Expression] = orExpression + protected lazy val expression: Parser[Expression] = + orExpression protected lazy val orExpression: Parser[Expression] = - andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) }) + andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) }) protected lazy val andExpression: Parser[Expression] = - comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) + comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) protected lazy val comparisonExpression: Parser[Expression] = - termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } | - termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | - termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | - termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | - termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | - termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { - case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - } | - termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | - termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | - termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | - termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { - case e1 ~ _ ~ _ ~ e2 => In(e1, e2) - } | - termExpression ~ NOT ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { - case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2)) - } | - termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } | - termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } | - NOT ~> termExpression ^^ {e => Not(e)} | - termExpression + ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) } + | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) } + | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) } + | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) } + | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } + | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } + | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } + | termExpression ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { + case e ~ el ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + } + | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } + | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } + | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { + case e1 ~ e2 => In(e1, e2) + } + | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { + case e1 ~ e2 => Not(In(e1, e2)) + } + | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } + | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } + | NOT ~> termExpression ^^ {e => Not(e)} + | termExpression + ) protected lazy val termExpression: Parser[Expression] = - productExpression * ( - "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } | - "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } ) + productExpression * + ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) } + | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) } + ) protected lazy val productExpression: Parser[Expression] = - baseExpression * ( - "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } | - "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } | - "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) } - ) + baseExpression * + ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) } + | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) } + | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) } + ) protected lazy val function: Parser[Expression] = - SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } | - SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } | - COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } | - COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } | - COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } | - APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { - case exp => ApproxCountDistinct(exp) - } | - APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { - case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) - } | - FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | - LAST ~> "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } | - AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | - MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } | - MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } | - UPPER ~> "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } | - LOWER ~> "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } | - IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { - case c ~ "," ~ t ~ "," ~ f => If(c,t,f) - } | - CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ - (ELSE ~> expression).? <~ END ^^ { - case casePart ~ altPart ~ elsePart => - val altExprs = altPart.flatMap { - case we ~ te => - Seq(casePart.fold(we)(EqualTo(_, we)), te) + ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } + | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } + | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } + | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } + | COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } + | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ + { case exp => ApproxCountDistinct(exp) } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } + | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } + | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } + | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } + | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } + | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } + | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } + | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } + | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ + { case c ~ t ~ f => If(c, t, f) } + | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ + (ELSE ~> expression).? <~ END ^^ { + case casePart ~ altPart ~ elsePart => + val altExprs = altPart.flatMap { case whenExpr ~ thenExpr => + Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr) + } + CaseWhen(altExprs ++ elsePart.toList) } - CaseWhen(altExprs ++ elsePart.toList) - } | - (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { - case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) - } | - (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { - case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) - } | - SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | - ABS ~> "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | - ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { - case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) - } + | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ + { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } + | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ + { case s ~ p ~ l => Substring(s, p, l) } + | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } + | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } + | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + ) protected lazy val cast: Parser[Expression] = - CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) } + CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) } protected lazy val literal: Parser[Literal] = - numericLit ^^ { - case i if i.toLong > Int.MaxValue => Literal(i.toLong) - case i => Literal(i.toInt) - } | - NULL ^^^ Literal(null, NullType) | - floatLit ^^ {case f => Literal(f.toDouble) } | - stringLit ^^ {case s => Literal(s, StringType) } + ( numericLit ^^ { + case i if i.toLong > Int.MaxValue => Literal(i.toLong) + case i => Literal(i.toInt) + } + | NULL ^^^ Literal(null, NullType) + | floatLit ^^ {case f => Literal(f.toDouble) } + | stringLit ^^ {case s => Literal(s, StringType) } + ) protected lazy val floatLit: Parser[String] = elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { - case base ~ _ ~ ordinal => GetItem(base, ordinal) - } | - (expression <~ ".") ~ ident ^^ { - case base ~ fieldName => GetField(base, fieldName) - } | - TRUE ^^^ Literal(true, BooleanType) | - FALSE ^^^ Literal(false, BooleanType) | - cast | - "(" ~> expression <~ ")" | - function | - "-" ~> literal ^^ UnaryMinus | - dotExpressionHeader | - ident ^^ UnresolvedAttribute | - "*" ^^^ Star(None) | - literal + ( expression ~ ("[" ~> expression <~ "]") ^^ + { case base ~ ordinal => GetItem(base, ordinal) } + | (expression <~ ".") ~ ident ^^ + { case base ~ fieldName => GetField(base, fieldName) } + | TRUE ^^^ Literal(true, BooleanType) + | FALSE ^^^ Literal(false, BooleanType) + | cast + | "(" ~> expression <~ ")" + | function + | "-" ~> literal ^^ UnaryMinus + | dotExpressionHeader + | ident ^^ UnresolvedAttribute + | "*" ^^^ Star(None) + | literal + ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { @@ -400,55 +338,3 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType } - -class SqlLexical(val keywords: Seq[String]) extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString = chars - } - - reserved ++= keywords.flatMap(w => allCaseVersions(w)) - - delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", "." - ) - - override lazy val token: Parser[Token] = ( - identChar ~ rep( identChar | digit ) ^^ - { case first ~ rest => processIdent(first :: rest mkString "") } - | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { - case i ~ None => NumericLit(i mkString "") - case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString("")) - } - | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ - { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } - | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ - { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '\"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar = letter | elem('_') - - override def whitespace: Parser[Any] = rep( - whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) - | '#' ~ rep( chrExcept(EofCh, '\n') ) - | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) - | '/' ~ '*' ~ failure("unclosed comment") - ) - - /** Generate all variations of upper and lower case of a given string */ - def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s == "") { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) ++ - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 9a3848cfc6b62..b8ba2ee428a20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -39,9 +39,9 @@ case class NativeCommand(cmd: String) extends Command { } /** - * Commands of the form "SET (key) (= value)". + * Commands of the form "SET [key [= value] ]". */ -case class SetCommand(key: Option[String], value: Option[String]) extends Command { +case class SetCommand(kv: Option[(String, Option[String])]) extends Command { override def output = Seq( AttributeReference("", StringType, nullable = false)()) } @@ -81,3 +81,14 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } + +/** + * Returned for the "! shellCommand" command + */ +case class ShellCommand(cmd: String) extends Command + + +/** + * Returned for the "SOURCE file" command + */ +case class SourceCommand(filePath: String) extends Command diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 014e1e2826724..23e7b2d270777 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -66,12 +66,17 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = true) + @transient protected[sql] val optimizer = Optimizer + @transient - protected[sql] val parser = new catalyst.SqlParser + protected[sql] val sqlParser = { + val fallback = new catalyst.SqlParser + new catalyst.SparkSQLParser(fallback(_)) + } - protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index bbf17b9fadf86..4f1af7234d551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -304,8 +304,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.SetCommand(key, value) => - Seq(execution.SetCommand(key, value, plan.output)(context)) + case logical.SetCommand(kv) => + Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheTableCommand(tableName, optPlan, isLazy) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index d49633c24ad4d..5859eba408ee1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -48,29 +48,28 @@ trait Command { * :: DeveloperApi :: */ @DeveloperApi -case class SetCommand( - key: Option[String], value: Option[String], output: Seq[Attribute])( +case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected lazy val sideEffectResult: Seq[Row] = (key, value) match { - // Set value for key k. - case (Some(k), Some(v)) => - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + override protected lazy val sideEffectResult: Seq[Row] = kv match { + // Set value for the key. + case Some((key, Some(value))) => + if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")) + context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) } else { - context.setConf(k, v) - Seq(Row(s"$k=$v")) + context.setConf(key, value) + Seq(Row(s"$key=$value")) } - // Query the value bound to key k. - case (Some(k), _) => + // Query the value bound to the key. + case Some((key, None)) => // TODO (lian) This is just a workaround to make the Simba ODBC driver work. // Should remove this once we get the ODBC driver updated. - if (k == "-v") { + if (key == "-v") { val hiveJars = Seq( "hive-exec-0.12.0.jar", "hive-service-0.12.0.jar", @@ -84,23 +83,20 @@ case class SetCommand( Row("system:java.class.path=" + hiveJars), Row("system:sun.java.command=shark.SharkServer2")) } else { - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) } else { - Seq(Row(s"$k=${context.getConf(k, "")}")) + Seq(Row(s"$key=${context.getConf(key, "")}")) } } // Query all key-value pairs that are set in the SQLConf of the context. - case (None, None) => + case _ => context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq - - case _ => - throw new IllegalArgumentException() } override def otherCopyArgs = context :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e624f97004f5..c87ded81fdc27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -69,7 +69,7 @@ class CachedTableSuite extends QueryTest { test("calling .unpersist() should drop in-memory columnar cache") { table("testData").cache() table("testData").count() - table("testData").unpersist(true) + table("testData").unpersist(blocking = true) assertCached(table("testData"), 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 79de1bb855dbe..a94022c0cf6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -42,7 +42,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { TimeZone.setDefault(origZone) } - test("SPARK-3176 Added Parser of SQL ABS()") { checkAnswer( sql("SELECT ABS(-1.3)"), @@ -61,7 +60,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 4) } - test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), @@ -694,6 +692,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { checkAnswer( - sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 910174a153768..accf61576b804 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -172,7 +172,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) result.queryExecution.logical match { - case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => sessionToActivePool(parentSession) = value logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index c5844e92eaaa9..430ffb29989ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -18,118 +18,50 @@ package org.apache.spark.sql.hive import scala.language.implicitConversions -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers + import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical} /** - * A parser that recognizes all HiveQL constructs together with several Spark SQL specific - * extensions like CACHE TABLE and UNCACHE TABLE. + * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. */ -private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { - - def apply(input: String): LogicalPlan = { - // Special-case out set commands since the value fields can be - // complex to handle without RegexParsers. Also this approach - // is clearer for the several possible cases of set commands. - if (input.trim.toLowerCase.startsWith("set")) { - input.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else if (input.trim.startsWith("!")) { - ShellCommand(input.drop(1)) - } else { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) - } - } - } - - protected case class Keyword(str: String) - - protected val ADD = Keyword("ADD") - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val DFS = Keyword("DFS") - protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SOURCE = Keyword("SOURCE") - protected val TABLE = Keyword("TABLE") - protected val UNCACHE = Keyword("UNCACHE") - +private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - protected def allCaseConverse(k: String): Parser[String] = - lexical.allCaseVersions(k).map(x => x : Parser[String]).reduce(_ | _) + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") + protected val FILE = Keyword("FILE") + protected val JAR = Keyword("JAR") - protected val reservedWords = - this.getClass + private val reservedWords = + this + .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) .map(_.invoke(this).asInstanceOf[Keyword].str) override val lexical = new SqlLexical(reservedWords) - protected lazy val query: Parser[LogicalPlan] = - cache | uncache | addJar | addFile | dfs | source | hiveQl + protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = restInput ^^ { - case statement => HiveQl.createPlan(statement.trim()) - } - - // Returns the whole input string - protected lazy val wholeInput: Parser[String] = new Parser[String] { - def apply(in: Input) = - Success(in.source.toString, in.drop(in.source.length())) - } - - // Returns the rest of the input string that are not parsed yet - protected lazy val restInput: Parser[String] = new Parser[String] { - def apply(in: Input) = - Success( - in.source.subSequence(in.offset, in.source.length).toString, - in.drop(in.source.length())) - } - - protected lazy val cache: Parser[LogicalPlan] = - CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> hiveQl) ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan, isLazy.isDefined) - } - - protected lazy val uncache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) + case statement => HiveQl.createPlan(statement.trim) } - protected lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> restInput ^^ { - case jar => AddJar(jar.trim()) + protected lazy val dfs: Parser[LogicalPlan] = + DFS ~> wholeInput ^^ { + case command => NativeCommand(command.trim) } - protected lazy val addFile: Parser[LogicalPlan] = + private lazy val addFile: Parser[LogicalPlan] = ADD ~ FILE ~> restInput ^^ { - case file => AddFile(file.trim()) + case input => AddFile(input.trim) } - protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> wholeInput ^^ { - case command => NativeCommand(command.trim()) - } - - protected lazy val source: Parser[LogicalPlan] = - SOURCE ~> restInput ^^ { - case file => SourceCommand(file.trim()) + private lazy val addJar: Parser[LogicalPlan] = + ADD ~ JAR ~> restInput ^^ { + case input => AddJar(input.trim) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 32c9175f181bb..98a46a31e1ffd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils +import org.apache.spark.sql.catalyst.SparkSQLParser import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -38,10 +39,6 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -private[hive] case class ShellCommand(cmd: String) extends Command - -private[hive] case class SourceCommand(filePath: String) extends Command - private[hive] case class AddFile(filePath: String) extends Command private[hive] case class AddJar(path: String) extends Command @@ -126,9 +123,11 @@ private[hive] object HiveQl { "TOK_CREATETABLE", "TOK_DESCTABLE" ) ++ nativeCommands - - // It parses hive sql query along with with several Spark SQL specific extensions - protected val hiveSqlParser = new ExtendedHiveQlParser + + protected val hqlParser = { + val fallback = new ExtendedHiveQlParser + new SparkSQLParser(fallback(_)) + } /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations @@ -218,7 +217,7 @@ private[hive] object HiveQl { def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hiveSqlParser(sql) + def parseSql(sql: String): LogicalPlan = hqlParser(sql) /** Creates LogicalPlan for a given HiveQL string. */ def createPlan(sql: String) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 508d8239c7628..5c66322f1ed99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -167,10 +167,10 @@ private[hive] trait HiveStrategies { database.get, tableName, query, - InsertIntoHiveTable(_: MetastoreRelation, - Map(), - query, - true)(hiveContext)) :: Nil + InsertIntoHiveTable(_: MetastoreRelation, + Map(), + query, + overwrite = true)(hiveContext)) :: Nil case _ => Nil } } From 421382d0e728940caa3e61bc11237c61f256378a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 9 Oct 2014 18:26:43 -0700 Subject: [PATCH 21/38] [SPARK-3824][SQL] Sets in-memory table default storage level to MEMORY_AND_DISK Using `MEMORY_AND_DISK` as default storage level for in-memory table caching. Due to the in-memory columnar representation, recomputing an in-memory cached table partitions can be very expensive. Author: Cheng Lian Closes #2686 from liancheng/spark-3824 and squashes the following commits: 35d2ed0 [Cheng Lian] Removes extra space 1ab7967 [Cheng Lian] Reduces test data size to fit DiskStore.getBytes() ba565f0 [Cheng Lian] Maks CachedBatch serializable 07f0204 [Cheng Lian] Sets in-memory table default storage level to MEMORY_AND_DISK --- .../main/scala/org/apache/spark/sql/CacheManager.scala | 10 +++++++--- .../spark/sql/columnar/InMemoryColumnarTableScan.scala | 9 +++++---- .../scala/org/apache/spark/sql/CachedTableSuite.scala | 10 +++++----- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index 3bf7382ac67a6..5ab2b5316ab10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel -import org.apache.spark.storage.StorageLevel.MEMORY_ONLY +import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) @@ -74,10 +74,14 @@ private[sql] trait CacheManager { cachedData.clear() } - /** Caches the data produced by the logical representation of the given schema rdd. */ + /** + * Caches the data produced by the logical representation of the given schema rdd. Unlike + * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing + * the in-memory columnar representation of the underlying table is expensive. + */ private[sql] def cacheQuery( query: SchemaRDD, - storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock { + storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.optimizedPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 4f79173a26f88..22ab0e2613f21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -38,7 +38,7 @@ private[sql] object InMemoryRelation { new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)() } -private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) +private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -91,7 +91,7 @@ private[sql] case class InMemoryRelation( val stats = Row.fromSeq( columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) - CachedBatch(columnBuilders.map(_.build()), stats) + CachedBatch(columnBuilders.map(_.build().array()), stats) } def hasNext = rowIterator.hasNext @@ -238,8 +238,9 @@ private[sql] case class InMemoryColumnarTableScan( def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors - val columnAccessors = - requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + val columnAccessors = requestedColumnIndices.map { batch => + ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch))) + } // Extract rows via column accessors new Iterator[Row] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c87ded81fdc27..444bc95009c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.storage.RDDBlockId +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -55,10 +55,10 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") - cacheTable("bigData") - assert(table("bigData").count() === 1000000L) - uncacheTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") + table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(table("bigData").count() === 200000L) + table("bigData").unpersist() } test("calling .cache() should use in-memory columnar caching") { From 6f98902a3d7749e543bc493a8c62b1e3a7b924cc Mon Sep 17 00:00:00 2001 From: ravipesala Date: Thu, 9 Oct 2014 18:41:36 -0700 Subject: [PATCH 22/38] [SPARK-3834][SQL] Backticks not correctly handled in subquery aliases The queries like SELECT a.key FROM (SELECT key FROM src) \`a\` does not work as backticks in subquery aliases are not handled properly. This PR fixes that. Author : ravipesala ravindra.pesalahuawei.com Author: ravipesala Closes #2737 from ravipesala/SPARK-3834 and squashes the following commits: 0e0ab98 [ravipesala] Fixing issue in backtick handling for subquery aliases --- .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../apache/spark/sql/hive/execution/SQLQuerySuite.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 98a46a31e1ffd..7cc14dc7a9c9e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -638,7 +638,7 @@ private[hive] object HiveQl { def nodeToRelation(node: Node): LogicalPlan = node match { case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - Subquery(alias, nodeToPlan(query)) + Subquery(cleanIdentifier(alias), nodeToPlan(query)) case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => val Token("TOK_SELECT", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3647bb1c4ce7d..fbe6ac765c009 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -68,5 +68,11 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), sql("SELECT `key` FROM src").collect().toSeq) - } + } + + test("SPARK-3834 Backticks not correctly handled in subquery aliases") { + checkAnswer( + sql("SELECT a.key FROM (SELECT key FROM src) `a`"), + sql("SELECT `key` FROM src").collect().toSeq) + } } From 411cf29fff011561f0093bb6101af87842828369 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Fri, 10 Oct 2014 00:46:56 -0700 Subject: [PATCH 23/38] [SPARK-2805] Upgrade Akka to 2.3.4 This is a second rev of the Akka upgrade (earlier merged, but reverted). I made a slight modification which is that I also upgrade Hive to deal with a compatibility issue related to the protocol buffers library. Author: Anand Avati Author: Patrick Wendell Closes #2752 from pwendell/akka-upgrade and squashes the following commits: 4c7ca3f [Patrick Wendell] Upgrading to new hive->protobuf version 57a2315 [Anand Avati] SPARK-1812: streaming - remove tests which depend on akka.actor.IO 2a551d3 [Anand Avati] SPARK-1812: core - upgrade to akka 2.3.4 --- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/AppClient.scala | 2 +- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 +- pom.xml | 4 +- .../spark/streaming/InputStreamsSuite.scala | 71 ------------------- 6 files changed, 7 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fef79ad1001f..cbc0bd178d894 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/pom.xml b/pom.xml index 7756c89b00cad..d047b9e307d4b 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.2.3-shaded-protobuf + 2.3.4-spark 1.7.5 1.2.17 1.0.4 @@ -127,7 +127,7 @@ 0.94.6 1.4.0 3.4.5 - 0.12.0 + 0.12.0-protobuf 1.4.3 1.2.3 8.1.14.v20131031 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index a44a45a3e9bd6..fa04fa326e370 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.streaming import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -143,59 +141,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - // TODO: This test works in IntelliJ but not through SBT - ignore("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - // Had to pass the local value of port to prevent from closing over entire scope - StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -377,22 +322,6 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } -/** This is an actor for testing actor input stream */ -class TestActor(port: Int) extends Actor with ActorHelper { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart(): Unit = { - @deprecated("suppress compile time deprecation warning", "1.0.0") - val unit = IOManager(context.system).connect(new InetSocketAddress(port)) - } - - def receive = { - case IO.Read(socket, bytes) => - store(bytesToString(bytes)) - } -} - /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { From 90f73fcc47c7bf881f808653d46a9936f37c3c31 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 10 Oct 2014 01:44:36 -0700 Subject: [PATCH 24/38] [SPARK-3889] Attempt to avoid SIGBUS by not mmapping files in ConnectionManager In general, individual shuffle blocks are frequently small, so mmapping them often creates a lot of waste. It may not be bad to mmap the larger ones, but it is pretty inconvenient to get configuration into ManagedBuffer, and besides it is unlikely to help all that much. Author: Aaron Davidson Closes #2742 from aarondav/mmap and squashes the following commits: a152065 [Aaron Davidson] Add other pathway back 52b6cd2 [Aaron Davidson] [SPARK-3889] Attempt to avoid SIGBUS by not mmapping files in ConnectionManager --- .../org/apache/spark/network/ManagedBuffer.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index a4409181ec907..4c9ca97a2a6b7 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -66,13 +66,27 @@ sealed abstract class ManagedBuffer { final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) extends ManagedBuffer { + /** + * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). + * Avoid unless there's a good reason not to. + */ + private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; + override def size: Long = length override def nioByteBuffer(): ByteBuffer = { var channel: FileChannel = null try { channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. + if (length < MIN_MEMORY_MAP_BYTES) { + val buf = ByteBuffer.allocate(length.toInt) + channel.read(buf, offset) + buf.flip() + buf + } else { + channel.map(MapMode.READ_ONLY, offset, length) + } } catch { case e: IOException => Try(channel.size).toOption match { From 72f36ee571ad27c7c7c70bb9aecc7e6ef51dfd44 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 10 Oct 2014 14:14:05 -0700 Subject: [PATCH 25/38] [SPARK-3886] [PySpark] use AutoBatchedSerializer by default Use AutoBatchedSerializer by default, which will choose the proper batch size based on size of serialized objects, let the size of serialized batch fall in into [64k - 640k]. In JVM, the serializer will also track the objects in batch to figure out duplicated objects, larger batch may cause OOM in JVM. Author: Davies Liu Closes #2740 from davies/batchsize and squashes the following commits: 52cdb88 [Davies Liu] update docs 185f2b9 [Davies Liu] use AutoBatchedSerializer by default --- python/pyspark/context.py | 11 +++++++---- python/pyspark/serializers.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6fb30d65c5edd..85c04624da4a6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -67,7 +67,7 @@ class SparkContext(object): _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, + environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, gateway=None): """ Create a new SparkContext. At least the master and app name should be set, @@ -83,8 +83,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, :param environment: A dictionary of environment variables to set on worker nodes. :param batchSize: The number of Python objects represented as a single - Java object. Set 1 to disable batching or -1 to use an - unlimited batch size. + Java object. Set 1 to disable batching, 0 to automatically choose + the batch size based on object sizes, or -1 to use an unlimited + batch size :param serializer: The serializer for RDDs. :param conf: A L{SparkConf} object setting Spark properties. :param gateway: Use an existing gateway and JVM, otherwise a new JVM @@ -117,6 +118,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._unbatched_serializer = serializer if batchSize == 1: self.serializer = self._unbatched_serializer + elif batchSize == 0: + self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 099fa54cf2bd7..3d1a34b281acc 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -220,7 +220,7 @@ class AutoBatchedSerializer(BatchedSerializer): Choose the size of batch automatically based on the size of object """ - def __init__(self, serializer, bestSize=1 << 20): + def __init__(self, serializer, bestSize=1 << 16): BatchedSerializer.__init__(self, serializer, -1) self.bestSize = bestSize @@ -247,7 +247,7 @@ def __eq__(self, other): other.serializer == self.serializer) def __str__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer<%s>" % str(self.serializer) class CartesianDeserializer(FramedSerializer): From 1d72a30874a88bdbab75217f001cf2af409016e7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 10 Oct 2014 16:49:19 -0700 Subject: [PATCH 26/38] HOTFIX: Fix build issue with Akka 2.3.4 upgrade. We had to upgrade our Hive 0.12 version as well to deal with a protobuf conflict (both hive and akka have been using a shaded protobuf version). This is testing a correctly patched version of Hive 0.12. Author: Patrick Wendell Closes #2756 from pwendell/hotfix and squashes the following commits: cc979d0 [Patrick Wendell] HOTFIX: Fix build issue with Akka 2.3.4 upgrade. --- pom.xml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index d047b9e307d4b..288bbf1114bea 100644 --- a/pom.xml +++ b/pom.xml @@ -127,7 +127,7 @@ 0.94.6 1.4.0 3.4.5 - 0.12.0-protobuf + 0.12.0-protobuf-2.5 1.4.3 1.2.3 8.1.14.v20131031 @@ -223,6 +223,18 @@ false + + + spark-staging + Spring Staging Repository + https://oss.sonatype.org/content/repositories/orgspark-project-1085 + + true + + + false + + From 0e8203f4fb721158fb27897680da476174d24c4b Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 10 Oct 2014 18:39:55 -0700 Subject: [PATCH 27/38] [SPARK-2924] Required by scala 2.11, only one fun/ctor amongst overriden alternatives, can have default argument(s). ...riden alternatives, can have default argument. Author: Prashant Sharma Closes #2750 from ScrapCodes/SPARK-2924/default-args-removed and squashes the following commits: d9785c3 [Prashant Sharma] [SPARK-2924] Required by scala 2.11, only one function/ctor amongst overriden alternatives, can have default argument. --- .../org/apache/spark/util/FileLogger.scala | 19 +++++++++++++++++-- .../apache/spark/util/FileLoggerSuite.scala | 8 ++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 6d1fc05a15d2c..fdc73f08261a6 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -51,12 +51,27 @@ private[spark] class FileLogger( def this( logDir: String, sparkConf: SparkConf, - compress: Boolean = false, - overwrite: Boolean = true) = { + compress: Boolean, + overwrite: Boolean) = { this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, overwrite = overwrite) } + def this( + logDir: String, + sparkConf: SparkConf, + compress: Boolean) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, + overwrite = true) + } + + def this( + logDir: String, + sparkConf: SparkConf) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = false, + overwrite = true) + } + private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index dc2a05631d83d..72466a3aa1130 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -74,13 +74,13 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { test("Logging when directory already exists") { // Create the logging directory multiple times - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() // If overwrite is not enabled, an exception should be thrown intercept[IOException] { - new FileLogger(logDirPathString, new SparkConf, overwrite = false).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = false).start() } } From 81015a2ba49583d730ce65b2262f50f1f2451a79 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Sat, 11 Oct 2014 11:26:17 -0700 Subject: [PATCH 28/38] [SPARK-3867][PySpark] ./python/run-tests failed when it run with Python 2.6 and unittest2 is not installed ./python/run-tests search a Python 2.6 executable on PATH and use it if available. When using Python 2.6, it is going to import unittest2 module which is not a standard library in Python 2.6, so it fails with ImportError. Author: cocoatomo Closes #2759 from cocoatomo/issues/3867-unittest2-import-error and squashes the following commits: f068eb5 [cocoatomo] [SPARK-3867] ./python/run-tests failed when it run with Python 2.6 and unittest2 is not installed --- python/pyspark/mllib/tests.py | 6 +++++- python/pyspark/tests.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5c20e100e144f..463faf7b6f520 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -25,7 +25,11 @@ from numpy import array, array_equal if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7f05d48ade2b3..ceab57464f013 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -34,7 +34,11 @@ from platform import python_implementation if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest From 7a3f589ef86200f99624fea8322e5af0cad774a7 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Sat, 11 Oct 2014 11:51:59 -0700 Subject: [PATCH 29/38] [SPARK-3909][PySpark][Doc] A corrupted format in Sphinx documents and building warnings Sphinx documents contains a corrupted ReST format and have some warnings. The purpose of this issue is same as https://issues.apache.org/jira/browse/SPARK-3773. commit: 0e8203f4fb721158fb27897680da476174d24c4b output ``` $ cd ./python/docs $ make clean html rm -rf _build/* sphinx-build -b html -d _build/doctrees . _build/html Making output directory... Running Sphinx v1.2.3 loading pickled environment... not yet created building [html]: targets for 4 source files that are out of date updating environment: 4 added, 0 changed, 0 removed reading sources... [100%] pyspark.sql /Users//MyRepos/Scala/spark/python/pyspark/mllib/feature.py:docstring of pyspark.mllib.feature.Word2VecModel.findSynonyms:4: WARNING: Field list ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/mllib/feature.py:docstring of pyspark.mllib.feature.Word2VecModel.transform:3: WARNING: Field list ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/sql.py:docstring of pyspark.sql:4: WARNING: Bullet list ends without a blank line; unexpected unindent. looking for now-outdated files... none found pickling environment... done checking consistency... done preparing documents... done writing output... [100%] pyspark.sql writing additional files... (12 module code pages) _modules/index search copying static files... WARNING: html_static_path entry u'/Users//MyRepos/Scala/spark/python/docs/_static' does not exist done copying extra files... done dumping search index... done dumping object inventory... done build succeeded, 4 warnings. Build finished. The HTML pages are in _build/html. ``` Author: cocoatomo Closes #2766 from cocoatomo/issues/3909-sphinx-build-warnings and squashes the following commits: 2c7faa8 [cocoatomo] [SPARK-3909][PySpark][Doc] A corrupted format in Sphinx documents and building warnings --- python/docs/conf.py | 2 +- python/pyspark/mllib/feature.py | 2 ++ python/pyspark/rdd.py | 2 +- python/pyspark/sql.py | 10 +++++----- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/docs/conf.py b/python/docs/conf.py index 8e6324f058251..e58d97ae6a746 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -131,7 +131,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +#html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index a44a27fd3b6a6..f4cbf31b94fe2 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -44,6 +44,7 @@ def transform(self, word): """ :param word: a word :return: vector representation of word + Transforms a word to its vector representation Note: local use only @@ -57,6 +58,7 @@ def findSynonyms(self, x, num): :param x: a word or a vector representation of word :param num: number of synonyms to find :return: array of (word, cosineSimilarity) + Find synonyms of a word Note: local use only diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6797d50659a92..e13bab946c44a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2009,7 +2009,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. - :param relativeSD Relative accuracy. Smaller values create + :param relativeSD: Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d3d36eb995ab6..b31a82f9b19ac 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -19,14 +19,14 @@ public classes of Spark SQL: - L{SQLContext} - Main entry point for SQL functionality. + Main entry point for SQL functionality. - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. - L{Row} - A Row of data returned by a Spark SQL query. + A Row of data returned by a Spark SQL query. - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. + Main entry point for accessing data stored in Apache Hive.. """ import itertools From 69c67abaa9d4bb4b95792d1862bc65efc764c194 Mon Sep 17 00:00:00 2001 From: giwa Date: Sun, 12 Oct 2014 02:46:56 -0700 Subject: [PATCH 30/38] [SPARK-2377] Python API for Streaming This patch brings Python API for Streaming. This patch is based on work from @giwa Author: giwa Author: Ken Takagiwa Author: Davies Liu Author: Ken Takagiwa Author: Tathagata Das Author: Ken Author: Ken Takagiwa Author: Matthew Farrellee Closes #2538 from davies/streaming and squashes the following commits: 64561e4 [Davies Liu] fix tests 331ecce [Davies Liu] fix example 3e2492b [Davies Liu] change updateStateByKey() to easy API 182be73 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 02d0575 [Davies Liu] add wrapper for foreachRDD() bebeb4a [Davies Liu] address all comments 6db00da [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 8380064 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 52c535b [Davies Liu] remove fix for sum() e108ec1 [Davies Liu] address comments 37fe06f [Davies Liu] use random port for callback server d05871e [Davies Liu] remove reuse of PythonRDD be5e5ff [Davies Liu] merge branch of env, make tests stable. 8071541 [Davies Liu] Merge branch 'env' into streaming c7bbbce [Davies Liu] fix sphinx docs 6bb9d91 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 4d0ea8b [Davies Liu] clear reference of SparkEnv after stop 54bd92b [Davies Liu] improve tests c2b31cb [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 7a88f9f [Davies Liu] rollback RDD.setContext(), use textFileStream() to test checkpointing bd8a4c2 [Davies Liu] fix scala style 7797c70 [Davies Liu] refactor ff88bec [Davies Liu] rename RDDFunction to TransformFunction d328aca [Davies Liu] fix serializer in queueStream 6f0da2f [Davies Liu] recover from checkpoint fa7261b [Davies Liu] refactor a13ff34 [Davies Liu] address comments 8466916 [Davies Liu] support checkpoint 9a16bd1 [Davies Liu] change number of partitions during tests b98d63f [Davies Liu] change private[spark] to private[python] eed6e2a [Davies Liu] rollback not needed changes e00136b [Davies Liu] address comments 069a94c [Davies Liu] fix the number of partitions during window() 338580a [Davies Liu] change _first(), _take(), _collect() as private API 19797f9 [Davies Liu] clean up 6ebceca [Davies Liu] add more tests c40c52d [Davies Liu] change first(), take(n) to has the same behavior as RDD 98ac6c2 [Davies Liu] support ssc.transform() b983f0f [Davies Liu] address comments 847f9b9 [Davies Liu] add more docs, add first(), take() e059ca2 [Davies Liu] move check of window into Python fce0ef5 [Davies Liu] rafactor of foreachRDD() 7001b51 [Davies Liu] refactor of queueStream() 26ea396 [Davies Liu] refactor 74df565 [Davies Liu] fix print and docs b32774c [Davies Liu] move java_import into streaming 604323f [Davies Liu] enable streaming tests c499ba0 [Davies Liu] remove Time and Duration 3f0fb4b [Davies Liu] refactor fix tests c28f520 [Davies Liu] support updateStateByKey d357b70 [Davies Liu] support windowed dstream bd13026 [Davies Liu] fix examples eec401e [Davies Liu] refactor, combine TransformedRDD, fix reuse PythonRDD, fix union 9a57685 [Davies Liu] fix python style bd27874 [Davies Liu] fix scala style 7339be0 [Davies Liu] delete tests 7f53086 [Davies Liu] support transform(), refactor and cleanup df098fc [Davies Liu] Merge branch 'master' into giwa 550dfd9 [giwa] WIP fixing 1.1 merge 5cdb6fa [giwa] changed for SCCallSiteSync e685853 [giwa] meged with rebased 1.1 branch 2d32a74 [giwa] added some StreamingContextTestSuite 4a59e1e [giwa] WIP:added more test for StreamingContext 8ffdbf1 [giwa] added atexit to handle callback server d5f5fcb [giwa] added comment for StreamingContext.sparkContext 63c881a [giwa] added StreamingContext.sparkContext d39f102 [giwa] added StreamingContext.remember d542743 [giwa] clean up code 2fdf0de [Matthew Farrellee] Fix scalastyle errors c0a06bc [giwa] delete not implemented functions f385976 [giwa] delete inproper comments b0f2015 [giwa] added comment in dstream._test_output bebb3f3 [giwa] remove the last brank line fbed8da [giwa] revert pom.xml 8ed93af [giwa] fixed explanaiton 066ba90 [giwa] revert pom.xml fa4af88 [giwa] remove duplicated import 6ae3caa [giwa] revert pom.xml 7dc7391 [giwa] fixed typo 62dc7a3 [giwa] clean up exmples f04882c [giwa] clen up examples b171ec3 [giwa] fixed pep8 violation f198d14 [giwa] clean up code 3166d31 [giwa] clean up c00e091 [giwa] change test case not to use awaitTermination e80647e [giwa] adopted the latest compression way of python command 58e41ff [giwa] merge with master 455e5af [giwa] removed wasted print in DStream af336b7 [giwa] add comments ddd4ee1 [giwa] added TODO coments 99ce042 [giwa] added saveAsTextFiles and saveAsPickledFiles 2a06cdb [giwa] remove waste duplicated code c5ecfc1 [giwa] basic function test cases are passed 8dcda84 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4 795b2cd [giwa] broke something 1e126bf [giwa] WIP: solved partitioned and None is not recognized f67cf57 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test 953deb0 [giwa] edited the comment to add more precise description af610d3 [giwa] removed unnesessary changes c1d546e [giwa] fixed PEP-008 violation 99410be [giwa] delete waste file b3b0362 [giwa] added basic operation test cases 9cde7c9 [giwa] WIP added test case bd3ba53 [giwa] WIP 5c04a5f [giwa] WIP: added PythonTestInputStream 019ef38 [giwa] WIP 1934726 [giwa] update comment 376e3ac [giwa] WIP 932372a [giwa] clean up dstream.py 0b09cff [giwa] added stop in StreamingContext 92e333e [giwa] implemented reduce and count function in Dstream 1b83354 [giwa] Removed the waste line 88f7506 [Ken Takagiwa] Kill py4j callback server properly 54b5358 [Ken Takagiwa] tried to restart callback server 4f07163 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server. fe02547 [Ken Takagiwa] remove waste file 2ad7bd3 [Ken Takagiwa] clean up codes 6197a11 [Ken Takagiwa] clean up code eb4bf48 [Ken Takagiwa] fix map function 98c2a00 [Ken Takagiwa] added count operation but this implementation need double check 58591d2 [Ken Takagiwa] reduceByKey is working 0df7111 [Ken Takagiwa] delete old file f485b1d [Ken Takagiwa] fied input of socketTextDStream dd6de81 [Ken Takagiwa] initial commit for socketTextStream 247fd74 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10 4bcb318 [Ken Takagiwa] implementing transform function in Python 38adf95 [Ken Takagiwa] added reducedByKey not working yet 66fcfff [Ken Takagiwa] modify dstream.py to fix indent error 41886c2 [Ken Takagiwa] comment PythonDStream.PairwiseDStream 0b99bec [Ken] initial commit for pySparkStreaming c214199 [giwa] added testcase for combineByKey 5625bdc [giwa] added gorupByKey testcase 10ab87b [giwa] added sparkContext as input parameter in StreamingContext 10b5b04 [giwa] removed wasted print in DStream e54f986 [giwa] add comments 16aa64f [giwa] added TODO coments 74535d4 [giwa] added saveAsTextFiles and saveAsPickledFiles f76c182 [giwa] remove waste duplicated code 18c8723 [giwa] modified streaming test case to add coment 13fb44c [giwa] basic function test cases are passed 3000b2b [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4 ff14070 [giwa] broke something bcdec33 [giwa] WIP: solved partitioned and None is not recognized 270a9e1 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test bb10956 [giwa] edited the comment to add more precise description 253a863 [giwa] removed unnesessary changes 3d37822 [giwa] fixed PEP-008 violation f21cab3 [giwa] delete waste file 878bad7 [giwa] added basic operation test cases ce2acd2 [giwa] WIP added test case 9ad6855 [giwa] WIP 1df77f5 [giwa] WIP: added PythonTestInputStream 1523b66 [giwa] WIP 8a0fbbc [giwa] update comment fe648e3 [giwa] WIP 29c2bc5 [giwa] initial commit for testcase 4d40d63 [giwa] clean up dstream.py c462bb3 [giwa] added stop in StreamingContext d2c01ba [giwa] clean up examples 3c45cd2 [giwa] implemented reduce and count function in Dstream b349649 [giwa] Removed the waste line 3b498e1 [Ken Takagiwa] Kill py4j callback server properly 84a9668 [Ken Takagiwa] tried to restart callback server 9ab8952 [Tathagata Das] Added extra line. 05e991b [Tathagata Das] Added missing file b1d2a30 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server. 678e854 [Ken Takagiwa] remove waste file 0a8bbbb [Ken Takagiwa] clean up codes bab31c1 [Ken Takagiwa] clean up code 72b9738 [Ken Takagiwa] fix map function d3ee86a [Ken Takagiwa] added count operation but this implementation need double check 15feea9 [Ken Takagiwa] edit python sparkstreaming example 6f98e50 [Ken Takagiwa] reduceByKey is working c455c8d [Ken Takagiwa] added reducedByKey not working yet dc6995d [Ken Takagiwa] delete old file b31446a [Ken Takagiwa] fixed typo of network_workdcount.py ccfd214 [Ken Takagiwa] added doctest for pyspark.streaming.duration 0d1b954 [Ken Takagiwa] fied input of socketTextDStream f746109 [Ken Takagiwa] initial commit for socketTextStream bb7ccf3 [Ken Takagiwa] remove unused import in python 224fc5e [Ken Takagiwa] add empty line d2099d8 [Ken Takagiwa] sorted the import following Spark coding convention 5bac7ec [Ken Takagiwa] revert streaming/pom.xml e1df940 [Ken Takagiwa] revert pom.xml 494cae5 [Ken Takagiwa] remove not implemented DStream functions in python 17a74c6 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10 1a0f065 [Ken Takagiwa] implementing transform function in Python d7b4d6f [Ken Takagiwa] added reducedByKey not working yet 87438e2 [Ken Takagiwa] modify dstream.py to fix indent error b406252 [Ken Takagiwa] comment PythonDStream.PairwiseDStream 454981d [Ken] initial commit for pySparkStreaming 150b94c [giwa] added some StreamingContextTestSuite f7bc8f9 [giwa] WIP:added more test for StreamingContext ee50c5a [giwa] added atexit to handle callback server fdc9125 [giwa] added comment for StreamingContext.sparkContext f5bfb70 [giwa] added StreamingContext.sparkContext da09768 [giwa] added StreamingContext.remember d68b568 [giwa] clean up code 4afa390 [giwa] clean up code 1fd6bc7 [Ken Takagiwa] Merge pull request #2 from mattf/giwa-master d9d59fe [Matthew Farrellee] Fix scalastyle errors 67473a9 [giwa] delete not implemented functions c97377c [giwa] delete inproper comments 2ea769e [giwa] added comment in dstream._test_output 3b27bd4 [giwa] remove the last brank line acfcaeb [giwa] revert pom.xml 93f7637 [giwa] fixed explanaiton 50fd6f9 [giwa] revert pom.xml 4f82c89 [giwa] remove duplicated import 9d1de23 [giwa] revert pom.xml 7339df2 [giwa] fixed typo 9c85e48 [giwa] clean up exmples 24f95db [giwa] clen up examples 0d30109 [giwa] fixed pep8 violation b7dab85 [giwa] improve test case 583e66d [giwa] move tests for streaming inside streaming directory 1d84142 [giwa] remove unimplement test f0ea311 [giwa] clean up code 171edeb [giwa] clean up 4dedd2d [giwa] change test case not to use awaitTermination 268a6a5 [giwa] Changed awaitTermination not to call awaitTermincation in Scala. Just use time.sleep instread 09a28bf [giwa] improve testcases 58150f5 [giwa] Changed the test case to focus the test operation 199e37f [giwa] adopted the latest compression way of python command 185fdbf [giwa] merge with master f1798c4 [giwa] merge with master e70f706 [giwa] added testcase for combineByKey e162822 [giwa] added gorupByKey testcase 97742fe [giwa] added sparkContext as input parameter in StreamingContext 14d4c0e [giwa] removed wasted print in DStream 6d8190a [giwa] add comments 4aa99e4 [giwa] added TODO coments e9fab72 [giwa] added saveAsTextFiles and saveAsPickledFiles 94f2b65 [giwa] remove waste duplicated code 580fbc2 [giwa] modified streaming test case to add coment 99e4bb3 [giwa] basic function test cases are passed 7051a84 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4 35933e1 [giwa] broke something 9767712 [giwa] WIP: solved partitioned and None is not recognized 4f2d7e6 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test 33c0f94d [giwa] edited the comment to add more precise description 774f18d [giwa] removed unnesessary changes 3a671cc [giwa] remove export PYSPARK_PYTHON in spark submit 8efa266 [giwa] fixed PEP-008 violation fa75d71 [giwa] delete waste file 7f96294 [giwa] added basic operation test cases 3dda31a [giwa] WIP added test case 1f68b78 [giwa] WIP c05922c [giwa] WIP: added PythonTestInputStream 1fd12ae [giwa] WIP c880a33 [giwa] update comment 5d22c92 [giwa] WIP ea4b06b [giwa] initial commit for testcase 5a9b525 [giwa] clean up dstream.py 79c5809 [giwa] added stop in StreamingContext 189dcea [giwa] clean up examples b8d7d24 [giwa] implemented reduce and count function in Dstream b6468e6 [giwa] Removed the waste line b47b5fd [Ken Takagiwa] Kill py4j callback server properly 19ddcdd [Ken Takagiwa] tried to restart callback server c9fc124 [Tathagata Das] Added extra line. 4caae3f [Tathagata Das] Added missing file 4eff053 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server. 5e822d4 [Ken Takagiwa] remove waste file aeaf8a5 [Ken Takagiwa] clean up codes 9fa249b [Ken Takagiwa] clean up code 05459c6 [Ken Takagiwa] fix map function a9f4ecb [Ken Takagiwa] added count operation but this implementation need double check d1ee6ca [Ken Takagiwa] edit python sparkstreaming example 0b8b7d0 [Ken Takagiwa] reduceByKey is working d25d5cf [Ken Takagiwa] added reducedByKey not working yet 7f7c5d1 [Ken Takagiwa] delete old file 967dc26 [Ken Takagiwa] fixed typo of network_workdcount.py 57fb740 [Ken Takagiwa] added doctest for pyspark.streaming.duration 4b69fb1 [Ken Takagiwa] fied input of socketTextDStream 02f618a [Ken Takagiwa] initial commit for socketTextStream 4ce4058 [Ken Takagiwa] remove unused import in python 856d98e [Ken Takagiwa] add empty line 490e338 [Ken Takagiwa] sorted the import following Spark coding convention 5594bd4 [Ken Takagiwa] revert pom.xml 2adca84 [Ken Takagiwa] remove not implemented DStream functions in python e551e13 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit 3758175 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit c5518b4 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10 dcf243f [Ken Takagiwa] implementing transform function in Python 9af03f4 [Ken Takagiwa] added reducedByKey not working yet 6e0d9c7 [Ken Takagiwa] modify dstream.py to fix indent error e497b9b [Ken Takagiwa] comment PythonDStream.PairwiseDStream 5c3a683 [Ken] initial commit for pySparkStreaming 665bfdb [giwa] added testcase for combineByKey a3d2379 [giwa] added gorupByKey testcase 636090a [giwa] added sparkContext as input parameter in StreamingContext e7ebb08 [giwa] removed wasted print in DStream d8b593b [giwa] add comments ea9c873 [giwa] added TODO coments 89ae38a [giwa] added saveAsTextFiles and saveAsPickledFiles e3033fc [giwa] remove waste duplicated code a14c7e1 [giwa] modified streaming test case to add coment 536def4 [giwa] basic function test cases are passed 2112638 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4 080541a [giwa] broke something 0704b86 [giwa] WIP: solved partitioned and None is not recognized 90a6484 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test a65f302 [giwa] edited the comment to add more precise description bdde697 [giwa] removed unnesessary changes e8c7bfc [giwa] remove export PYSPARK_PYTHON in spark submit 3334169 [giwa] fixed PEP-008 violation db0a303 [giwa] delete waste file 2cfd3a0 [giwa] added basic operation test cases 90ae568 [giwa] WIP added test case a120d07 [giwa] WIP f671cdb [giwa] WIP: added PythonTestInputStream 56fae45 [giwa] WIP e35e101 [giwa] Merge branch 'master' into testcase ba5112d [giwa] update comment 28aa56d [giwa] WIP fb08559 [giwa] initial commit for testcase a613b85 [giwa] clean up dstream.py c40c0ef [giwa] added stop in StreamingContext 31e4260 [giwa] clean up examples d2127d6 [giwa] implemented reduce and count function in Dstream 48f7746 [giwa] Removed the waste line 0f83eaa [Ken Takagiwa] delete py4j 0.8.1 1679808 [Ken Takagiwa] Kill py4j callback server properly f96cd4e [Ken Takagiwa] tried to restart callback server fe86198 [Ken Takagiwa] add py4j 0.8.2.1 but server is not launched 1064fe0 [Ken Takagiwa] Merge branch 'master' of https://github.com/giwa/spark 28c6620 [Ken Takagiwa] Implemented DStream.foreachRDD in the Python API using Py4J callback server 85b0fe1 [Ken Takagiwa] Merge pull request #1 from tdas/python-foreach 54e2e8c [Tathagata Das] Added extra line. e185338 [Tathagata Das] Added missing file a778d4b [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server. cc2092b [Ken Takagiwa] remove waste file d042ac6 [Ken Takagiwa] clean up codes 84a021f [Ken Takagiwa] clean up code bd20e17 [Ken Takagiwa] fix map function d01a125 [Ken Takagiwa] added count operation but this implementation need double check 7d05109 [Ken Takagiwa] merge with remote branch ae464e0 [Ken Takagiwa] edit python sparkstreaming example 04af046 [Ken Takagiwa] reduceByKey is working 3b6d7b0 [Ken Takagiwa] implementing transform function in Python 571d52d [Ken Takagiwa] added reducedByKey not working yet 5720979 [Ken Takagiwa] delete old file e604fcb [Ken Takagiwa] fixed typo of network_workdcount.py 4b7c08b [Ken Takagiwa] Merge branch 'master' of https://github.com/giwa/spark ce7d426 [Ken Takagiwa] added doctest for pyspark.streaming.duration a8c9fd5 [Ken Takagiwa] fixed for socketTextStream a61fa9e [Ken Takagiwa] fied input of socketTextDStream 1e84f41 [Ken Takagiwa] initial commit for socketTextStream 6d012f7 [Ken Takagiwa] remove unused import in python 25d30d5 [Ken Takagiwa] add empty line 6e0a64a [Ken Takagiwa] sorted the import following Spark coding convention fa4a7fc [Ken Takagiwa] revert streaming/pom.xml 8f8202b [Ken Takagiwa] revert streaming pom.xml c9d79dd [Ken Takagiwa] revert pom.xml 57e3e52 [Ken Takagiwa] remove not implemented DStream functions in python 0a516f5 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit a7a0b5c [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit 72bfc66 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10 69e9cd3 [Ken Takagiwa] implementing transform function in Python 94a0787 [Ken Takagiwa] added reducedByKey not working yet 88068cf [Ken Takagiwa] modify dstream.py to fix indent error 1367be5 [Ken Takagiwa] comment PythonDStream.PairwiseDStream eb2b3ba [Ken] Merge remote-tracking branch 'upstream/master' d8e51f9 [Ken] initial commit for pySparkStreaming --- .../apache/spark/api/python/PythonRDD.scala | 10 +- .../main/python/streaming/hdfs_wordcount.py | 49 ++ .../python/streaming/network_wordcount.py | 48 ++ .../streaming/stateful_network_wordcount.py | 57 ++ python/docs/epytext.py | 2 +- python/docs/index.rst | 1 + python/docs/pyspark.rst | 3 +- python/pyspark/context.py | 8 +- python/pyspark/serializers.py | 3 + python/pyspark/streaming/__init__.py | 21 + python/pyspark/streaming/context.py | 325 +++++++++ python/pyspark/streaming/dstream.py | 621 ++++++++++++++++++ python/pyspark/streaming/tests.py | 545 +++++++++++++++ python/pyspark/streaming/util.py | 128 ++++ python/run-tests | 7 + .../streaming/api/java/JavaDStreamLike.scala | 2 +- .../streaming/api/python/PythonDStream.scala | 316 +++++++++ 17 files changed, 2133 insertions(+), 13 deletions(-) create mode 100644 examples/src/main/python/streaming/hdfs_wordcount.py create mode 100644 examples/src/main/python/streaming/network_wordcount.py create mode 100644 examples/src/main/python/streaming/stateful_network_wordcount.py create mode 100644 python/pyspark/streaming/__init__.py create mode 100644 python/pyspark/streaming/context.py create mode 100644 python/pyspark/streaming/dstream.py create mode 100644 python/pyspark/streaming/tests.py create mode 100644 python/pyspark/streaming/util.py create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index c74f86548ef85..4acbdf9d5e25f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials -import scala.reflect.ClassTag -import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} @@ -42,7 +40,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils private[spark] class PythonRDD( - parent: RDD[_], + @transient parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -55,9 +53,9 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = parent.partitions + override def getPartitions = firstParent.partitions - override val partitioner = if (preservePartitoning) parent.partitioner else None + override val partitioner = if (preservePartitoning) firstParent.partitioner else None override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -234,7 +232,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py new file mode 100644 index 0000000000000..40faff0ccc7db --- /dev/null +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in new text files created in the given directory + Usage: hdfs_wordcount.py + is the directory that Spark Streaming will use to find and read new text files. + + To run this on your local machine on directory `localdir`, run this example + $ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localdir + + Then create a text file in `localdir` and the words in the file will get counted. +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, "Usage: hdfs_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingHDFSWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.textFileStream(sys.argv[1]) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda x: (x, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py new file mode 100644 index 0000000000000..cfa9c1ff5bfbc --- /dev/null +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingNetworkWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py new file mode 100644 index 0000000000000..18a9a5a452ffb --- /dev/null +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: stateful_network_wordcount.py + and describe the TCP server that Spark Streaming + would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ + localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: stateful_network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") + ssc = StreamingContext(sc, 1) + ssc.checkpoint("checkpoint") + + def updateFunc(new_values, last_sum): + return sum(new_values) + (last_sum or 0) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + running_counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .updateStateByKey(updateFunc) + + running_counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/python/docs/epytext.py b/python/docs/epytext.py index 61d731bff570d..19fefbfc057a4 100644 --- a/python/docs/epytext.py +++ b/python/docs/epytext.py @@ -5,7 +5,7 @@ (r"L{([\w.()]+)}", r":class:`\1`"), (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), (r"C{([\w.()]+)}", r":class:`\1`"), - (r"[IBCM]{(.+)}", r"`\1`"), + (r"[IBCM]{([^}]+)}", r"`\1`"), ('pyspark.rdd.RDD', 'RDD'), ) diff --git a/python/docs/index.rst b/python/docs/index.rst index d66e051b15371..703bef644de28 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -13,6 +13,7 @@ Contents: pyspark pyspark.sql + pyspark.streaming pyspark.mllib diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index a68bd62433085..e81be3b6cb796 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -7,8 +7,9 @@ Subpackages .. toctree:: :maxdepth: 1 - pyspark.mllib pyspark.sql + pyspark.streaming + pyspark.mllib Contents -------- diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 85c04624da4a6..89d2e2e5b4a8e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -68,7 +68,7 @@ class SparkContext(object): def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, - gateway=None): + gateway=None, jsc=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -104,14 +104,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf) + conf, jsc) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf): + conf, jsc): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -154,7 +154,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment[varName] = v # Create the Java SparkContext through Py4J - self._jsc = self._initialize_context(self._conf._jconf) + self._jsc = jsc or self._initialize_context(self._conf._jconf) # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 3d1a34b281acc..08a0f0d8ffb3e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -114,6 +114,9 @@ def __ne__(self, other): def __repr__(self): return "<%s object>" % self.__class__.__name__ + def __hash__(self): + return hash(str(self)) + class FramedSerializer(Serializer): diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py new file mode 100644 index 0000000000000..d2644a1d4ffab --- /dev/null +++ b/python/pyspark/streaming/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.dstream import DStream + +__all__ = ['StreamingContext', 'DStream'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py new file mode 100644 index 0000000000000..dc9dc41121935 --- /dev/null +++ b/python/pyspark/streaming/context.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys + +from py4j.java_collections import ListConverter +from py4j.java_gateway import java_import, JavaObject + +from pyspark import RDD, SparkConf +from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.context import SparkContext +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.dstream import DStream +from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer + +__all__ = ["StreamingContext"] + + +def _daemonize_callback_server(): + """ + Hack Py4J to daemonize callback server + + The thread of callback server has daemon=False, it will block the driver + from exiting if it's not shutdown. The following code replace `start()` + of CallbackServer with a new version, which set daemon=True for this + thread. + + Also, it will update the port number (0) with real port + """ + # TODO: create a patch for Py4J + import socket + import py4j.java_gateway + logger = py4j.java_gateway.logger + from py4j.java_gateway import Py4JNetworkError + from threading import Thread + + def start(self): + """Starts the CallbackServer. This method should be called by the + client instead of run().""" + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + 1) + try: + self.server_socket.bind((self.address, self.port)) + if not self.port: + # update port with real port + self.port = self.server_socket.getsockname()[1] + except Exception as e: + msg = 'An error occurred while trying to start the callback server: %s' % e + logger.exception(msg) + raise Py4JNetworkError(msg) + + # Maybe thread needs to be cleanup up? + self.thread = Thread(target=self.run) + self.thread.daemon = True + self.thread.start() + + py4j.java_gateway.CallbackServer.start = start + + +class StreamingContext(object): + """ + Main entry point for Spark Streaming functionality. A StreamingContext + represents the connection to a Spark cluster, and can be used to create + L{DStream} various input sources. It can be from an existing L{SparkContext}. + After creating and transforming DStreams, the streaming computation can + be started and stopped using `context.start()` and `context.stop()`, + respectively. `context.awaitTransformation()` allows the current thread + to wait for the termination of the context by `stop()` or by an exception. + """ + _transformerSerializer = None + + def __init__(self, sparkContext, batchDuration=None, jssc=None): + """ + Create a new StreamingContext. + + @param sparkContext: L{SparkContext} object. + @param batchDuration: the time interval (in seconds) at which streaming + data will be divided into batches + """ + + self._sc = sparkContext + self._jvm = self._sc._jvm + self._jssc = jssc or self._initialize_context(self._sc, batchDuration) + + def _initialize_context(self, sc, duration): + self._ensure_initialized() + return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) + + def _jduration(self, seconds): + """ + Create Duration object given number of seconds + """ + return self._jvm.Duration(int(seconds * 1000)) + + @classmethod + def _ensure_initialized(cls): + SparkContext._ensure_initialized() + gw = SparkContext._gateway + + java_import(gw.jvm, "org.apache.spark.streaming.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") + + # start callback server + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__: + _daemonize_callback_server() + # use random port + gw._start_callback_server(0) + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + + # register serializer for TransformFunction + # it happens before creating SparkContext when loading from checkpointing + cls._transformerSerializer = TransformFunctionSerializer( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) + + @classmethod + def getOrCreate(cls, checkpointPath, setupFunc): + """ + Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + recreated from the checkpoint data. If the data does not exist, then the provided setupFunc + will be used to create a JavaStreamingContext. + + @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + @param setupFunc Function to create a new JavaStreamingContext and setup DStreams + """ + # TODO: support checkpoint in HDFS + if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): + ssc = setupFunc() + ssc.checkpoint(checkpointPath) + return ssc + + cls._ensure_initialized() + gw = SparkContext._gateway + + try: + jssc = gw.jvm.JavaStreamingContext(checkpointPath) + except Exception: + print >>sys.stderr, "failed to load StreamingContext from checkpoint" + raise + + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # update ctx in serializer + SparkContext._active_spark_context = sc + cls._transformerSerializer.ctx = sc + return StreamingContext(sc, None, jssc) + + @property + def sparkContext(self): + """ + Return SparkContext which is associated with this StreamingContext. + """ + return self._sc + + def start(self): + """ + Start the execution of the streams. + """ + self._jssc.start() + + def awaitTermination(self, timeout=None): + """ + Wait for the execution to stop. + @param timeout: time to wait in seconds + """ + if timeout is None: + self._jssc.awaitTermination() + else: + self._jssc.awaitTermination(int(timeout * 1000)) + + def stop(self, stopSparkContext=True, stopGraceFully=False): + """ + Stop the execution of the streams, with option of ensuring all + received data has been processed. + + @param stopSparkContext: Stop the associated SparkContext or not + @param stopGracefully: Stop gracefully by waiting for the processing + of all received data to be completed + """ + self._jssc.stop(stopSparkContext, stopGraceFully) + if stopSparkContext: + self._sc.stop() + + def remember(self, duration): + """ + Set each DStreams in this context to remember RDDs it generated + in the last given duration. DStreams remember RDDs only for a + limited duration of time and releases them for garbage collection. + This method allows the developer to specify how to long to remember + the RDDs (if the developer wishes to query old data outside the + DStream computation). + + @param duration: Minimum duration (in seconds) that each DStream + should remember its RDDs + """ + self._jssc.remember(self._jduration(duration)) + + def checkpoint(self, directory): + """ + Sets the context to periodically checkpoint the DStream operations for master + fault-tolerance. The graph will be checkpointed every batch interval. + + @param directory: HDFS-compatible directory where the checkpoint data + will be reliably stored + """ + self._jssc.checkpoint(directory) + + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input from TCP source hostname:port. Data is received using + a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited + lines. + + @param hostname: Hostname to connect to for receiving data + @param port: Port to connect to for receiving data + @param storageLevel: Storage level to use for storing the received objects + """ + jlevel = self._sc._getJavaStorageLevel(storageLevel) + return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self, + UTF8Deserializer()) + + def textFileStream(self, directory): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as text files. Files must be wrriten to the + monitored directory by "moving" them from another location within the same + file system. File names starting with . are ignored. + """ + return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + + def _check_serializers(self, rdds): + # make sure they have same serializer + if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: + for i in range(len(rdds)): + # reset them to sc.serializer + rdds[i] = rdds[i]._reserialize() + + def queueStream(self, rdds, oneAtATime=True, default=None): + """ + Create an input stream from an queue of RDDs or list. In each batch, + it will process either one or all of the RDDs returned by the queue. + + NOTE: changes to the queue after the stream is created will not be recognized. + + @param rdds: Queue of RDDs + @param oneAtATime: pick one rdd each time or pick all of them once. + @param default: The default rdd if no more in rdds + """ + if default and not isinstance(default, RDD): + default = self._sc.parallelize(default) + + if not rdds and default: + rdds = [rdds] + + if rdds and not isinstance(rdds[0], RDD): + rdds = [self._sc.parallelize(input) for input in rdds] + self._check_serializers(rdds) + + jrdds = ListConverter().convert([r._jrdd for r in rdds], + SparkContext._gateway._gateway_client) + queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + if default: + default = default._reserialize(rdds[0]._jrdd_deserializer) + jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) + else: + jdstream = self._jssc.queueStream(queue, oneAtATime) + return DStream(jdstream, self, rdds[0]._jrdd_deserializer) + + def transform(self, dstreams, transformFunc): + """ + Create a new DStream in which each RDD is generated by applying + a function on RDDs of the DStreams. The order of the JavaRDDs in + the transform function parameter will be the same as the order + of corresponding DStreams in the list. + """ + jdstreams = ListConverter().convert([d._jdstream for d in dstreams], + SparkContext._gateway._gateway_client) + # change the final serializer to sc.serializer + func = TransformFunction(self._sc, + lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + *[d._jrdd_deserializer for d in dstreams]) + jfunc = self._jvm.TransformFunction(func) + jdstream = self._jssc.transform(jdstreams, jfunc) + return DStream(jdstream, self, self._sc.serializer) + + def union(self, *dstreams): + """ + Create a unified DStream from multiple DStreams of the same + type and same slide duration. + """ + if not dstreams: + raise ValueError("should have at least one DStream to union") + if len(dstreams) == 1: + return dstreams[0] + if len(set(s._jrdd_deserializer for s in dstreams)) > 1: + raise ValueError("All DStreams should have same serializer") + if len(set(s._slideDuration for s in dstreams)) > 1: + raise ValueError("All DStreams should have same slide duration") + first = dstreams[0] + jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], + SparkContext._gateway._gateway_client) + return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py new file mode 100644 index 0000000000000..5ae5cf07f0137 --- /dev/null +++ b/python/pyspark/streaming/dstream.py @@ -0,0 +1,621 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from itertools import chain, ifilter, imap +import operator +import time +from datetime import datetime + +from py4j.protocol import Py4JJavaError + +from pyspark import RDD +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.util import rddToFileName, TransformFunction +from pyspark.rdd import portable_hash +from pyspark.resultiterable import ResultIterable + +__all__ = ["DStream"] + + +class DStream(object): + """ + A Discretized Stream (DStream), the basic abstraction in Spark Streaming, + is a continuous sequence of RDDs (of the same type) representing a + continuous stream of data (see L{RDD} in the Spark core documentation + for more details on RDDs). + + DStreams can either be created from live data (such as, data from TCP + sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be + generated by transforming existing DStreams using operations such as + `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming + program is running, each DStream periodically generates a RDD, either + from live data or by transforming the RDD generated by a parent DStream. + + DStreams internally is characterized by a few basic properties: + - A list of other DStreams that the DStream depends on + - A time interval at which the DStream generates an RDD + - A function that is used to generate an RDD after each time interval + """ + def __init__(self, jdstream, ssc, jrdd_deserializer): + self._jdstream = jdstream + self._ssc = ssc + self._sc = ssc._sc + self._jrdd_deserializer = jrdd_deserializer + self.is_cached = False + self.is_checkpointed = False + + def context(self): + """ + Return the StreamingContext associated with this DStream + """ + return self._ssc + + def count(self): + """ + Return a new DStream in which each RDD has a single element + generated by counting each RDD of this DStream. + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add) + + def filter(self, f): + """ + Return a new DStream containing only the elements that satisfy predicate. + """ + def func(iterator): + return ifilter(f, iterator) + return self.mapPartitions(func, True) + + def flatMap(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to all elements of + this DStream, and then flattening the results + """ + def func(s, iterator): + return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def map(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to each element of DStream. + """ + def func(iterator): + return imap(f, iterator) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitions() to each RDDs of this DStream. + """ + def func(s, iterator): + return f(iterator) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitionsWithIndex() to each RDDs of this DStream. + """ + return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning)) + + def reduce(self, func): + """ + Return a new DStream in which each RDD has a single element + generated by reducing each RDD of this DStream. + """ + return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) + + def reduceByKey(self, func, numPartitions=None): + """ + Return a new DStream by applying reduceByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.combineByKey(lambda x: x, func, func, numPartitions) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numPartitions=None): + """ + Return a new DStream by applying combineByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def func(rdd): + return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) + return self.transform(func) + + def partitionBy(self, numPartitions, partitionFunc=portable_hash): + """ + Return a copy of the DStream in which each RDD are partitioned + using the specified partitioner. + """ + return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.func_code.co_argcount == 1: + old_func = func + func = lambda t, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def pprint(self): + """ + Print the first ten elements of each RDD generated in this DStream. + """ + def takeAndPrint(time, rdd): + taken = rdd.take(11) + print "-------------------------------------------" + print "Time: %s" % time + print "-------------------------------------------" + for record in taken[:10]: + print record + if len(taken) > 10: + print "..." + print + + self.foreachRDD(takeAndPrint) + + def mapValues(self, f): + """ + Return a new DStream by applying a map function to the value of + each key-value pairs in this DStream without changing the key. + """ + map_values_fn = lambda (k, v): (k, f(v)) + return self.map(map_values_fn, preservesPartitioning=True) + + def flatMapValues(self, f): + """ + Return a new DStream by applying a flatmap function to the value + of each key-value pairs in this DStream without changing the key. + """ + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMap(flat_map_fn, preservesPartitioning=True) + + def glom(self): + """ + Return a new DStream in which RDD is generated by applying glom() + to RDD of this DStream. + """ + def func(iterator): + yield list(iterator) + return self.mapPartitions(func) + + def cache(self): + """ + Persist the RDDs of this DStream with the default storage level + (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self.persist(StorageLevel.MEMORY_ONLY_SER) + return self + + def persist(self, storageLevel): + """ + Persist the RDDs of this DStream with the given storage level + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdstream.persist(javaStorageLevel) + return self + + def checkpoint(self, interval): + """ + Enable periodic checkpointing of RDDs of this DStream + + @param interval: time in seconds, after each period of that, generated + RDD will be checkpointed + """ + self.is_checkpointed = True + self._jdstream.checkpoint(self._ssc._jduration(interval)) + return self + + def groupByKey(self, numPartitions=None): + """ + Return a new DStream by applying groupByKey on each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) + + def countByValue(self): + """ + Return a new DStream in which each RDD contains the counts of each + distinct value in each RDD of this DStream. + """ + return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() + + def saveAsTextFiles(self, prefix, suffix=None): + """ + Save each RDD in this DStream as at text file, using string + representation of elements. + """ + def saveAsTextFile(t, rdd): + path = rddToFileName(prefix, suffix, t) + try: + rdd.saveAsTextFile(path) + except Py4JJavaError as e: + # after recovered from checkpointing, the foreachRDD may + # be called twice + if 'FileAlreadyExistsException' not in str(e): + raise + return self.foreachRDD(saveAsTextFile) + + # TODO: uncomment this until we have ssc.pickleFileStream() + # def saveAsPickleFiles(self, prefix, suffix=None): + # """ + # Save each RDD in this DStream as at binary file, the elements are + # serialized by pickle. + # """ + # def saveAsPickleFile(t, rdd): + # path = rddToFileName(prefix, suffix, t) + # try: + # rdd.saveAsPickleFile(path) + # except Py4JJavaError as e: + # # after recovered from checkpointing, the foreachRDD may + # # be called twice + # if 'FileAlreadyExistsException' not in str(e): + # raise + # return self.foreachRDD(saveAsPickleFile) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.func_code.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.func_code.co_argcount == 2, "func should take one or two arguments" + return TransformedDStream(self, func) + + def transformWith(self, func, other, keepSerializer=False): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream and 'other' DStream. + + `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three + arguments of (`time`, `rdd_a`, `rdd_b`) + """ + if func.func_code.co_argcount == 2: + oldfunc = func + func = lambda t, a, b: oldfunc(a, b) + assert func.func_code.co_argcount == 3, "func should take two or three arguments" + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), + other._jdstream.dstream(), jfunc) + jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer + return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) + + def repartition(self, numPartitions): + """ + Return a new DStream with an increased or decreased level of parallelism. + """ + return self.transform(lambda rdd: rdd.repartition(numPartitions)) + + @property + def _slideDuration(self): + """ + Return the slideDuration in seconds of this DStream + """ + return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0 + + def union(self, other): + """ + Return a new DStream by unifying data of another DStream with this DStream. + + @param other: Another DStream having the same interval (i.e., slideDuration) + as this DStream. + """ + if self._slideDuration != other._slideDuration: + raise ValueError("the two DStream should have same slide duration") + return self.transformWith(lambda a, b: a.union(b), other, True) + + def cogroup(self, other, numPartitions=None): + """ + Return a new DStream by applying 'cogroup' between RDDs of this + DStream and `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) + + def join(self, other, numPartitions=None): + """ + Return a new DStream by applying 'join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.join(b, numPartitions), other) + + def leftOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'left outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) + + def rightOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'right outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) + + def fullOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'full outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) + + def _jtime(self, timestamp): + """ Convert datetime or unix_timestamp into Time + """ + if isinstance(timestamp, datetime): + timestamp = time.mktime(timestamp.timetuple()) + return self._sc._jvm.Time(long(timestamp * 1000)) + + def slice(self, begin, end): + """ + Return all the RDDs between 'begin' to 'end' (both included) + + `begin`, `end` could be datetime.datetime() or unix_timestamp + """ + jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) + return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds] + + def _validate_window_param(self, window, slide): + duration = self._jdstream.dstream().slideDuration().milliseconds() + if int(window * 1000) % duration != 0: + raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" + % duration) + if slide and int(slide * 1000) % duration != 0: + raise ValueError("slideDuration must be multiple of the slide duration (%d ms)" + % duration) + + def window(self, windowDuration, slideDuration=None): + """ + Return a new DStream in which each RDD contains all the elements in seen in a + sliding window of time over this DStream. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + self._validate_window_param(windowDuration, slideDuration) + d = self._ssc._jduration(windowDuration) + if slideDuration is None: + return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) + s = self._ssc._jduration(slideDuration) + return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer) + + def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated by reducing all + elements in a sliding window over this DStream. + + if `invReduceFunc` is not None, the reduction is done incrementally + using the old window's reduced value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + This is more efficient than `invReduceFunc` is None. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse reduce function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + keyed = self.map(lambda x: (1, x)) + reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, + windowDuration, slideDuration, 1) + return reduced.map(lambda (k, v): v) + + def countByWindow(self, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated + by counting the number of elements in a window over this DStream. + windowDuration and slideDuration are as defined in the window() operation. + + This is equivalent to window(windowDuration, slideDuration).count(), + but will be more efficient if window is large. + """ + return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub, + windowDuration, slideDuration) + + def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream in which each RDD contains the count of distinct elements in + RDDs in a sliding window over this DStream. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + """ + keyed = self.map(lambda x: (x, 1)) + counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, + windowDuration, slideDuration, numPartitions) + return counted.filter(lambda (k, v): v > 0).count() + + def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream by applying `groupByKey` over a sliding window. + Similar to `DStream.groupByKey()`, but applies it over a sliding window. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: Number of partitions of each RDD in the new DStream. + """ + ls = self.mapValues(lambda x: [x]) + grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):], + windowDuration, slideDuration, numPartitions) + return grouped.mapValues(ResultIterable) + + def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None, + numPartitions=None, filterFunc=None): + """ + Return a new DStream by applying incremental `reduceByKey` over a sliding window. + + The reduced value of over a new window is calculated using the old window's reduce value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + + `invFunc` can be None, then it will reduce all the RDDs in window, could be slower + than having `invFunc`. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + @param filterFunc: function to filter expired key-value pairs; + only pairs that satisfy the function are retained + set this to null if you do not want to filter + """ + self._validate_window_param(windowDuration, slideDuration) + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + reduced = self.reduceByKey(func, numPartitions) + + def reduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r + + def invReduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + joined = a.leftOuterJoin(b, numPartitions) + return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) + if invReduceFunc: + jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) + else: + jinvReduceFunc = None + if slideDuration is None: + slideDuration = self._slideDuration + dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), + jreduceFunc, jinvReduceFunc, + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + def updateStateByKey(self, updateFunc, numPartitions=None): + """ + Return a new "state" DStream where the state for each key is updated by applying + the given function on the previous state of the key and the new values of the key. + + @param updateFunc: State update function. If this function returns None, then + corresponding state key-value pair will be eliminated. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def reduceFunc(t, a, b): + if a is None: + g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) + else: + g = a.cogroup(b, numPartitions) + g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) + state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) + return state.filter(lambda (k, v): v is not None) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, + self._sc.serializer, self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + +class TransformedDStream(DStream): + """ + TransformedDStream is an DStream generated by an Python function + transforming each RDD of an DStream to another RDDs. + + Multiple continuous transformations of DStream can be combined into + one transformation. + """ + def __init__(self, prev, func): + self._ssc = prev._ssc + self._sc = self._ssc._sc + self._jrdd_deserializer = self._sc.serializer + self.is_cached = False + self.is_checkpointed = False + self._jdstream_val = None + + if (isinstance(prev, TransformedDStream) and + not prev.is_cached and not prev.is_checkpointed): + prev_func = prev.func + self.func = lambda t, rdd: func(t, prev_func(t, rdd)) + self.prev = prev.prev + else: + self.prev = prev + self.func = func + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py new file mode 100644 index 0000000000000..a8d876d0fa3b3 --- /dev/null +++ b/python/pyspark/streaming/tests.py @@ -0,0 +1,545 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from itertools import chain +import time +import operator +import unittest +import tempfile + +from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.streaming.context import StreamingContext + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 10 # seconds + duration = 1 + + def setUp(self): + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + self.sc = SparkContext(appName=class_name, conf=conf) + self.sc.setCheckpointDir("/tmp") + # TODO: decrease duration to speed up tests + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + self.ssc.stop() + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print "timeout after", self.timeout + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. + if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) + + +class BasicOperationTests(PySparkStreamingTestCase): + + def test_map(self): + """Basic operation test for DStream.map.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.map(str) + expected = map(lambda x: map(str, x), input) + self._test_func(input, func, expected) + + def test_flatMap(self): + """Basic operation test for DStream.faltMap.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), + input) + self._test_func(input, func, expected) + + def test_filter(self): + """Basic operation test for DStream.filter.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.filter(lambda x: x % 2 == 0) + expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + self._test_func(input, func, expected) + + def test_count(self): + """Basic operation test for DStream.count.""" + input = [range(5), range(10), range(20)] + + def func(dstream): + return dstream.count() + expected = map(lambda x: [len(x)], input) + self._test_func(input, func, expected) + + def test_reduce(self): + """Basic operation test for DStream.reduce.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.reduce(operator.add) + expected = map(lambda x: [reduce(operator.add, x)], input) + self._test_func(input, func, expected) + + def test_reduceByKey(self): + """Basic operation test for DStream.reduceByKey.""" + input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)], + [("", 1), ("", 1), ("", 1), ("", 1)], + [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]] + + def func(dstream): + return dstream.reduceByKey(operator.add) + expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]] + self._test_func(input, func, expected, sort=True) + + def test_mapValues(self): + """Basic operation test for DStream.mapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 2), (3, 3)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.mapValues(lambda x: x + 10) + expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], + [("", 14), (1, 11), (2, 12), (3, 13)], + [(1, 11), (2, 11), (3, 11), (4, 11)]] + self._test_func(input, func, expected, sort=True) + + def test_flatMapValues(self): + """Basic operation test for DStream.flatMapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 1), (3, 1)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.flatMapValues(lambda x: (x, x + 10)) + expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), + ("c", 1), ("c", 11), ("d", 1), ("d", 11)], + [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] + self._test_func(input, func, expected) + + def test_glom(self): + """Basic operation test for DStream.glom.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.glom() + expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] + self._test_func(rdds, func, expected) + + def test_mapPartitions(self): + """Basic operation test for DStream.mapPartitions.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + def f(iterator): + yield sum(iterator) + return dstream.mapPartitions(f) + expected = [[3, 7], [11, 15], [19, 23]] + self._test_func(rdds, func, expected) + + def test_countByValue(self): + """Basic operation test for DStream.countByValue.""" + input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + + def func(dstream): + return dstream.countByValue() + expected = [[4], [4], [3]] + self._test_func(input, func, expected) + + def test_groupByKey(self): + """Basic operation test for DStream.groupByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + return dstream.groupByKey().mapValues(list) + + expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])], + [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])], + [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]] + self._test_func(input, func, expected, sort=True) + + def test_combineByKey(self): + """Basic operation test for DStream.combineByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + def add(a, b): + return a + str(b) + return dstream.combineByKey(str, add, add) + expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")], + [(1, "111"), (2, "11"), (3, "1")], + [("a", "11"), ("b", "1"), ("", "111")]] + self._test_func(input, func, expected, sort=True) + + def test_repartition(self): + input = [range(1, 5), range(5, 9)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.repartition(1).glom() + expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] + self._test_func(rdds, func, expected) + + def test_union(self): + input1 = [range(3), range(5), range(6)] + input2 = [range(3, 6), range(5, 6)] + + def func(d1, d2): + return d1.union(d2) + + expected = [range(6), range(6), range(6)] + self._test_func(input1, func, expected, input2=input2) + + def test_cogroup(self): + input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]] + input2 = [[(1, 2)], + [(4, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]] + + def func(d1, d2): + return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs))) + + expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], + [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], + [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]] + self._test_func(input, func, expected, sort=True, input2=input2) + + def test_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.join(b) + + expected = [[('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_left_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.leftOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_right_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.rightOuterJoin(b) + + expected = [[('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_full_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.fullOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_update_state_by_key(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + + +class WindowFunctionTests(PySparkStreamingTestCase): + + timeout = 20 + + def test_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.window(3, 1).count() + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.countByWindow(3, 1) + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window_large(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByWindow(5, 1) + + expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] + self._test_func(input, func, expected) + + def test_count_by_value_and_window(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByValueAndWindow(5, 1) + + expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + self._test_func(input, func, expected) + + def test_group_by_key_and_window(self): + input = [[('a', i)] for i in range(5)] + + def func(dstream): + return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + + expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], + [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] + self._test_func(input, func, expected) + + def test_reduce_by_invalid_window(self): + input1 = [range(3), range(5), range(1), range(6)] + d1 = self.ssc.queueStream(input1) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + + def _add_input_stream(self): + inputs = map(lambda x: range(1, x), range(101)) + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop() + self.ssc.stop() + + def test_queue_stream(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], result) + + def test_union(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + +class CheckpointTests(PySparkStreamingTestCase): + + def setUp(self): + pass + + def test_get_or_create(self): + inputd = tempfile.mkdtemp() + outputd = tempfile.mkdtemp() + "/" + + def updater(vs, s): + return sum(vs, s or 0) + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) + wc = dstream.updateStateByKey(updater) + wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") + wc.checkpoint(.5) + return ssc + + cpd = tempfile.mkdtemp("test_streaming_cps") + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + + def check_output(n): + while not os.listdir(outputd): + time.sleep(0.1) + time.sleep(1) # make sure mtime is larger than the previous one + with open(os.path.join(inputd, str(n)), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) + + while True: + p = os.path.join(outputd, max(os.listdir(outputd))) + if '_SUCCESS' not in os.listdir(p): + # not finished + time.sleep(0.01) + continue + ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + d = ordd.values().map(int).collect() + if not d: + time.sleep(0.01) + continue + self.assertEqual(10, len(d)) + s = set(d) + self.assertEqual(1, len(s)) + m = s.pop() + if n > m: + continue + self.assertEqual(n, m) + break + + check_output(1) + check_output(2) + ssc.stop(True, True) + + time.sleep(1) + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + check_output(3) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py new file mode 100644 index 0000000000000..86ee5aa04f252 --- /dev/null +++ b/python/pyspark/streaming/util.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +from datetime import datetime +import traceback + +from pyspark import SparkContext, RDD + + +class TransformFunction(object): + """ + This class wraps a function RDD[X] -> RDD[Y] that was passed to + DStream.transform(), allowing it to be called from Java via Py4J's + callback server. + + Java calls this function with a sequence of JavaRDDs and this function + returns a single JavaRDD pointer back to Java. + """ + _emptyRDD = None + + def __init__(self, ctx, func, *deserializers): + self.ctx = ctx + self.func = func + self.deserializers = deserializers + + def call(self, milliseconds, jrdds): + try: + if self.ctx is None: + self.ctx = SparkContext._active_spark_context + if not self.ctx or not self.ctx._jsc: + # stopped + return + + # extend deserializers with the first one + sers = self.deserializers + if len(sers) < len(jrdds): + sers += (sers[0],) * (len(jrdds) - len(sers)) + + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + for jrdd, ser in zip(jrdds, sers)] + t = datetime.fromtimestamp(milliseconds / 1000.0) + r = self.func(t, *rdds) + if r: + return r._jrdd + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunction(%s)" % self.func + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction'] + + +class TransformFunctionSerializer(object): + """ + This class implements a serializer for PythonTransformFunction Java + objects. + + This is necessary because the Java PythonTransformFunction objects are + actually Py4J references to Python objects and thus are not directly + serializable. When Java needs to serialize a PythonTransformFunction, + it uses this class to invoke Python, which returns the serialized function + as a byte array. + """ + def __init__(self, ctx, serializer, gateway=None): + self.ctx = ctx + self.serializer = serializer + self.gateway = gateway or self.ctx._gateway + self.gateway.jvm.PythonDStream.registerSerializer(self) + + def dumps(self, id): + try: + func = self.gateway.gateway_property.pool[id] + return bytearray(self.serializer.dumps((func.func, func.deserializers))) + except Exception: + traceback.print_exc() + + def loads(self, bytes): + try: + f, deserializers = self.serializer.loads(str(bytes)) + return TransformFunction(self.ctx, f, *deserializers) + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunctionSerializer(%s)" % self.serializer + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer'] + + +def rddToFileName(prefix, suffix, timestamp): + """ + Return string prefix-time(.suffix) + + >>> rddToFileName("spark", None, 12345678910) + 'spark-12345678910' + >>> rddToFileName("spark", "tmp", 12345678910) + 'spark-12345678910.tmp' + """ + if isinstance(timestamp, datetime): + seconds = time.mktime(timestamp.timetuple()) + timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + if suffix is None: + return prefix + "-" + str(timestamp) + else: + return prefix + "-" + str(timestamp) + "." + suffix + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/run-tests b/python/run-tests index f6a96841175e8..2f98443c30aef 100755 --- a/python/run-tests +++ b/python/run-tests @@ -81,6 +81,11 @@ function run_mllib_tests() { run_test "pyspark/mllib/tests.py" } +function run_streaming_tests() { + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -96,6 +101,7 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_streaming_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -105,6 +111,7 @@ if [ $(which pypy) ]; then run_core_tests run_sql_tests + run_streaming_tests fi if [[ $FAILED == 0 ]]; then diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a6184de4e83c1..2a7004e56ef53 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } - /** + /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition * of the RDD. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala new file mode 100644 index 0000000000000..213dff6a76354 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.python + +import java.io.{ObjectInputStream, ObjectOutputStream} +import java.lang.reflect.Proxy +import java.util.{ArrayList => JArrayList, List => JList} +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.language.existentials + +import py4j.GatewayServer + +import org.apache.spark.api.java._ +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Interval, Duration, Time} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.api.java._ + + +/** + * Interface for Python callback function which is used to transform RDDs + */ +private[python] trait PythonTransformFunction { + def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] +} + +/** + * Interface for Python Serializer to serialize PythonTransformFunction + */ +private[python] trait PythonTransformFunctionSerializer { + def dumps(id: String): Array[Byte] + def loads(bytes: Array[Byte]): PythonTransformFunction +} + +/** + * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J) + * so that it looks like a Scala function and can be transparently serialized and + * deserialized by Java. + */ +private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction) + extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { + + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) + .map(_.rdd) + } + + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + } + + // for function.Function2 + def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { + pfunc.call(time.milliseconds, rdds) + } + + private def writeObject(out: ObjectOutputStream): Unit = { + val bytes = PythonTransformFunctionSerializer.serialize(pfunc) + out.writeInt(bytes.length) + out.write(bytes) + } + + private def readObject(in: ObjectInputStream): Unit = { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + pfunc = PythonTransformFunctionSerializer.deserialize(bytes) + } +} + +/** + * Helpers for PythonTransformFunctionSerializer + * + * PythonTransformFunctionSerializer is logically a singleton that's happens to be + * implemented as a Python object. + */ +private[python] object PythonTransformFunctionSerializer { + + /** + * A serializer in Python, used to serialize PythonTransformFunction + */ + private var serializer: PythonTransformFunctionSerializer = _ + + /* + * Register a serializer from Python, should be called during initialization + */ + def register(ser: PythonTransformFunctionSerializer): Unit = { + serializer = ser + } + + def serialize(func: PythonTransformFunction): Array[Byte] = { + assert(serializer != null, "Serializer has not been registered!") + // get the id of PythonTransformFunction in py4j + val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) + val f = h.getClass().getDeclaredField("id") + f.setAccessible(true) + val id = f.get(h).asInstanceOf[String] + serializer.dumps(id) + } + + def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + assert(serializer != null, "Serializer has not been registered!") + serializer.loads(bytes) + } +} + +/** + * Helper functions, which are called from Python via Py4J. + */ +private[python] object PythonDStream { + + /** + * can not access PythonTransformFunctionSerializer.register() via Py4j + * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM + */ + def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = { + PythonTransformFunctionSerializer.register(ser) + } + + /** + * Update the port of callback client to `port` + */ + def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = { + val cl = gws.getCallbackClient + val f = cl.getClass.getDeclaredField("port") + f.setAccessible(true) + f.setInt(cl, port) + } + + /** + * helper function for DStream.foreachRDD(), + * cannot be `foreachRDD`, it will confusing py4j + */ + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) { + val func = new TransformFunction((pfunc)) + jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) + } + + /** + * convert list of RDD into queue of RDDs, for ssc.queueStream() + */ + def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { + val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] + rdds.forall(queue.add(_)) + queue + } +} + +/** + * Base class for PythonDStream with some common methods + */ +private[python] abstract class PythonDStream( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * Transformed DStream in Python. + */ +private[python] class PythonTransformedDStream ( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends PythonDStream(parent, pfunc) { + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(rdd, validTime) + } else { + None + } + } +} + +/** + * Transformed from two DStreams in Python. + */ +private[python] class PythonTransformed2DStream( + parent: DStream[_], + parent2: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent, parent2) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val empty: RDD[_] = ssc.sparkContext.emptyRDD + val rdd1 = parent.getOrCompute(validTime).getOrElse(empty) + val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty) + func(Some(rdd1), Some(rdd2), validTime) + } + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * similar to StateDStream + */ +private[python] class PythonStateDStream( + parent: DStream[Array[Byte]], + @transient reduceFunc: PythonTransformFunction) + extends PythonDStream(parent, reduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val lastState = getOrCompute(validTime - slideDuration) + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(lastState, rdd, validTime) + } else { + lastState + } + } +} + +/** + * similar to ReducedWindowedDStream + */ +private[python] class PythonReducedWindowedDStream( + parent: DStream[Array[Byte]], + @transient preduceFunc: PythonTransformFunction, + @transient pinvReduceFunc: PythonTransformFunction, + _windowDuration: Duration, + _slideDuration: Duration) + extends PythonDStream(parent, preduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + val invReduceFunc = new TransformFunction(pinvReduceFunc) + + def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val currentTime = validTime + val current = new Interval(currentTime - windowDuration, currentTime) + val previous = current - slideDuration + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + val previousRDD = getOrCompute(previous.endTime) + + // for small window, reduce once will be better than twice + if (pinvReduceFunc != null && previousRDD.isDefined + && windowDuration >= slideDuration * 5) { + + // subtract the values from old RDDs + val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime) + val subtracted = if (oldRDDs.size > 0) { + invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime) + } else { + previousRDD + } + + // add the RDDs of the reduced values in "new time steps" + val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime) + if (newRDDs.size > 0) { + func(subtracted, Some(ssc.sc.union(newRDDs)), validTime) + } else { + subtracted + } + } else { + // Get the RDDs of the reduced values in current window + val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime) + if (currentRDDs.size > 0) { + func(None, Some(ssc.sc.union(currentRDDs)), validTime) + } else { + None + } + } + } +} From 18bd67c24b081f113b34455692451571c466df92 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 12 Oct 2014 13:08:42 -0700 Subject: [PATCH 31/38] [SPARK-3887] Send stracktrace in ConnectionManager error replies When reporting that a remote error occurred, the ConnectionManager should also log the stacktrace of the remote exception. This PR accomplishes this by sending the remote exception's stacktrace as the payload in the "negative ACK / error message." Author: Josh Rosen Closes #2741 from JoshRosen/propagate-cm-exceptions-to-sender and squashes the following commits: b5366cc [Josh Rosen] Explicitly encode error messages using UTF-8. cef18b3 [Josh Rosen] [SPARK-3887] Send stracktrace in ConnectionManager error messages. --- .../spark/network/nio/ConnectionManager.scala | 10 ++++++---- .../org/apache/spark/network/nio/Message.scala | 14 ++++++++++++++ .../network/nio/NioBlockTransferService.scala | 11 ++++------- .../spark/network/nio/ConnectionManagerSuite.scala | 6 ++++-- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 6b00190c5eccc..9396b6ba84e7e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -748,9 +748,7 @@ private[nio] class ConnectionManager( } catch { case e: Exception => { logError(s"Exception was thrown while processing message", e) - val m = Message.createBufferMessage(bufferMessage.id) - m.hasError = true - ackMessage = Some(m) + ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) } } finally { sendMessage(connectionManagerId, ackMessage.getOrElse { @@ -913,8 +911,12 @@ private[nio] class ConnectionManager( } case scala.util.Success(ackMessage) => if (ackMessage.hasError) { + val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head + val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) + errorMsgByteBuf.get(errorMsgBytes) + val errorMsg = new String(errorMsgBytes, "utf-8") val e = new IOException( - "sendMessageReliably failed with ACK that signalled a remote error") + s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") if (!promise.tryFailure(e)) { logWarning("Ignore error because promise is completed", e) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index 0b874c2891255..3ad04591da658 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.util.Utils private[nio] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -84,6 +85,19 @@ private[nio] object Message { createBufferMessage(new Array[ByteBuffer](0), ackId) } + /** + * Create a "negative acknowledgment" to notify a sender that an error occurred + * while processing its message. The exception's stacktrace will be formatted + * as a string, serialized into a byte array, and sent as the message payload. + */ + def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { + val exceptionString = Utils.exceptionString(exception) + val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8")) + val errorMessage = createBufferMessage(serializedExceptionString, ackId) + errorMessage.hasError = true + errorMessage + } + def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { case BUFFER_MESSAGE => new BufferMessage(header.id, diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index b389b9a2022c6..5add4fc433fb3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -151,17 +151,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } catch { case e: Exception => { logError("Exception handling buffer message", e) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + Some(Message.createErrorMessage(e, msg.id)) } } case otherMessage: Any => - logError("Unknown type message received: " + otherMessage) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" + logError(errorMsg) + Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) } } diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 9f49587cdc670..b70734dfe37cf 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -27,6 +27,7 @@ import scala.language.postfixOps import org.scalatest.FunSuite import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.util.Utils /** * Test the ConnectionManager with various security settings. @@ -236,7 +237,7 @@ class ConnectionManagerSuite extends FunSuite { val manager = new ConnectionManager(0, conf, securityManager) val managerServer = new ConnectionManager(0, conf, securityManager) managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - throw new Exception + throw new Exception("Custom exception text") }) val size = 10 * 1024 * 1024 @@ -246,9 +247,10 @@ class ConnectionManagerSuite extends FunSuite { val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - intercept[IOException] { + val exception = intercept[IOException] { Await.result(future, 1 second) } + assert(Utils.exceptionString(exception).contains("Custom exception text")) manager.stop() managerServer.stop() From e5be4de7bcf5aa7afc856fc665427ff2b22a0fcd Mon Sep 17 00:00:00 2001 From: NamelessAnalyst Date: Sun, 12 Oct 2014 14:18:55 -0700 Subject: [PATCH 32/38] SPARK-3716 [GraphX] Update Analytics.scala for partitionStrategy assignment Previously, when the val partitionStrategy was created it called a function in the Analytics object which was a copy of the PartitionStrategy.fromString() method. This function has been removed, and the assignment of partitionStrategy now uses the PartitionStrategy.fromString method instead. In this way, it better matches the declarations of edge/vertex StorageLevel variables. Author: NamelessAnalyst Closes #2569 from NamelessAnalyst/branch-1.1 and squashes the following commits: c24ff51 [NamelessAnalyst] Update Analytics.scala (cherry picked from commit 5a21e3e7e97f135c81c664098a723434b910f09d) Signed-off-by: Ankur Dave --- .../spark/examples/graphx/Analytics.scala | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index c4317a6aec798..45527d9382fd0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -46,17 +46,6 @@ object Analytics extends Logging { } val options = mutable.Map(optionsList: _*) - def pickPartitioner(v: String): PartitionStrategy = { - // TODO: Use reflection rather than listing all the partitioning strategies here. - v match { - case "RandomVertexCut" => RandomVertexCut - case "EdgePartition1D" => EdgePartition1D - case "EdgePartition2D" => EdgePartition2D - case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut - case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v) - } - } - val conf = new SparkConf() .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") @@ -67,7 +56,7 @@ object Analytics extends Logging { sys.exit(1) } val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy") - .map(pickPartitioner(_)) + .map(PartitionStrategy.fromString(_)) val edgeStorageLevel = options.remove("edgeStorageLevel") .map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY) val vertexStorageLevel = options.remove("vertexStorageLevel") @@ -107,7 +96,7 @@ object Analytics extends Logging { if (!outFname.isEmpty) { logWarning("Saving pageranks of pages to " + outFname) - pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname) + pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname) } sc.stop() @@ -129,7 +118,7 @@ object Analytics extends Logging { val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) val cc = ConnectedComponents.run(graph) - println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct()) + println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct()) sc.stop() case "triangles" => @@ -147,7 +136,7 @@ object Analytics extends Logging { minEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) - // TriangleCount requires the graph to be partitioned + // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) println("Triangles: " + triangles.vertices.map { From c86c9760374f331ab7ed173b0a022250635485d3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sun, 12 Oct 2014 15:41:27 -0700 Subject: [PATCH 33/38] [HOTFIX] Fix compilation error for Yarn 2.0.*-alpha This was reported in https://issues.apache.org/jira/browse/SPARK-3445. There are API differences between the 0.23.* and the 2.0.*-alpha branches that are not accounted for when this code was introduced. Author: Andrew Or Closes #2776 from andrewor14/fix-yarn-alpha and squashes the following commits: ec94752 [Andrew Or] Fix compilation error for 2.0.*-alpha --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5a20532315e59..5c7bca4541222 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -122,7 +122,7 @@ private[spark] class Client( * ApplicationReport#getClientToken is renamed `getClientToAMToken` in the stable API. */ override def getClientToken(report: ApplicationReport): String = - Option(report.getClientToken).getOrElse("") + Option(report.getClientToken).map(_.toString).getOrElse("") } object Client { From fc616d51a510f82627b5be949a5941419834cf70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Dubovsk=C3=BD?= Date: Sun, 12 Oct 2014 22:03:26 -0700 Subject: [PATCH 34/38] [SPARK-3121] Wrong implementation of implicit bytesWritableConverter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit val path = ... //path to seq file with BytesWritable as type of both key and value val file = sc.sequenceFile[Array[Byte],Array[Byte]](path) file.take(1)(0)._1 This prints incorrect content of byte array. Actual content starts with correct one and some "random" bytes and zeros are appended. BytesWritable has two methods: getBytes() - return content of all internal array which is often longer then actual value stored. It usually contains the rest of previous longer values copyBytes() - return just begining of internal array determined by internal length property It looks like in implicit conversion between BytesWritable and Array[byte] getBytes is used instead of correct copyBytes. dbtsai Author: Jakub Dubovský Author: Dubovsky Jakub Closes #2712 from james64/3121-bugfix and squashes the following commits: f85d24c [Jakub Dubovský] Test name changed, comments added 1b20d51 [Jakub Dubovský] Import placed correctly 406e26c [Jakub Dubovský] Scala style fixed f92ffa6 [Dubovsky Jakub] performance tuning 480f9cd [Dubovsky Jakub] Bug 3121 fixed --- .../scala/org/apache/spark/SparkContext.scala | 6 ++- .../org/apache/spark/SparkContextSuite.scala | 40 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/SparkContextSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 396cdd1247e07..b709b8880ba76 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,6 +21,7 @@ import scala.language.implicitConversions import java.io._ import java.net.URI +import java.util.Arrays import java.util.concurrent.atomic.AtomicInteger import java.util.{Properties, UUID} import java.util.UUID.randomUUID @@ -1429,7 +1430,10 @@ object SparkContext extends Logging { simpleWritableConverter[Boolean, BooleanWritable](_.get) implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { - simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + simpleWritableConverter[Array[Byte], BytesWritable](bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + ) } implicit def stringWritableConverter(): WritableConverter[String] = diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala new file mode 100644 index 0000000000000..31edad1c56c73 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import org.apache.hadoop.io.BytesWritable + +class SparkContextSuite extends FunSuite { + //Regression test for SPARK-3121 + test("BytesWritable implicit conversion is correct") { + val bytesWritable = new BytesWritable() + val inputArray = (1 to 10).map(_.toByte).toArray + bytesWritable.set(inputArray, 0, 10) + bytesWritable.set(inputArray, 0, 5) + + val converter = SparkContext.bytesWritableConverter() + val byteArray = converter.convert(bytesWritable) + assert(byteArray.length === 5) + + bytesWritable.set(inputArray, 0, 0) + val byteArray2 = converter.convert(bytesWritable) + assert(byteArray2.length === 0) + } +} From b4a7fa7a663c462bf537ca9d63af0dba6b4a8033 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sun, 12 Oct 2014 22:48:54 -0700 Subject: [PATCH 35/38] [SPARK-3905][Web UI]The keys for sorting the columns of Executor page ,Stage page Storage page are incorrect Author: GuoQiang Li Closes #2763 from witgo/SPARK-3905 and squashes the following commits: 17d7990 [GuoQiang Li] The keys for sorting the columns of Executor page ,Stage page Storage page are incorrect --- .../org/apache/spark/ui/jobs/ExecutorTable.scala | 12 ++++++------ .../scala/org/apache/spark/ui/jobs/StageTable.scala | 6 +++--- .../org/apache/spark/ui/storage/StoragePage.scala | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 2987dc04494a5..f0e43fbf70976 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -71,19 +71,19 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr {k} {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} - {UIUtils.formatDuration(v.taskTime)} + {UIUtils.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} - + {Utils.bytesToString(v.inputBytes)} - + {Utils.bytesToString(v.shuffleRead)} - + {Utils.bytesToString(v.shuffleWrite)} - + {Utils.bytesToString(v.memoryBytesSpilled)} - + {Utils.bytesToString(v.diskBytesSpilled)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 2e67310594784..4ee7f08ab47a2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -176,9 +176,9 @@ private[ui] class StageTableBase( {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, stageData.numFailedTasks, s.numTasks)} - {inputReadWithUnit} - {shuffleReadWithUnit} - {shuffleWriteWithUnit} + {inputReadWithUnit} + {shuffleReadWithUnit} + {shuffleWriteWithUnit} } /** Render an HTML row that represents a stage */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 716591c9ed449..83489ca0679ee 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -58,9 +58,9 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.numCachedPartitions} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} - {Utils.bytesToString(rdd.memSize)} - {Utils.bytesToString(rdd.tachyonSize)} - {Utils.bytesToString(rdd.diskSize)} + {Utils.bytesToString(rdd.memSize)} + {Utils.bytesToString(rdd.tachyonSize)} + {Utils.bytesToString(rdd.diskSize)} // scalastyle:on } From d8b8c210786dfb905d06ea0a21d633f7772d5d1a Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Sun, 12 Oct 2014 23:05:14 -0700 Subject: [PATCH 36/38] Add echo "Run streaming tests ..." Author: Ken Takagiwa Closes #2778 from giwa/patch-2 and squashes the following commits: a59f9a1 [Ken Takagiwa] Add echo "Run streaming tests ..." --- python/run-tests | 1 + 1 file changed, 1 insertion(+) diff --git a/python/run-tests b/python/run-tests index 2f98443c30aef..80acd002ab7eb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -82,6 +82,7 @@ function run_mllib_tests() { } function run_streaming_tests() { + echo "Run streaming tests ..." run_test "pyspark/streaming/util.py" run_test "pyspark/streaming/tests.py" } From 92e017fb894be1e8e2b2b5274fec4c31a7a4412e Mon Sep 17 00:00:00 2001 From: w00228970 Date: Sun, 12 Oct 2014 23:35:50 -0700 Subject: [PATCH 37/38] [SPARK-3899][Doc]fix wrong links in streaming doc There are three [Custom Receiver Guide] links in streaming doc, the first is wrong. Author: w00228970 Author: wangfei Closes #2749 from scwf/streaming-doc and squashes the following commits: 0cd76b7 [wangfei] update link tojump to the Akka-specific section 45b0646 [w00228970] wrong link in streaming doc --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 5c21e912ea160..738309c668387 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -494,7 +494,7 @@ methods for creating DStreams from files and Akka actors as input sources. For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. -- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details. +- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html#implementing-and-using-a-custom-actor-based-receiver) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. From 942847fd94c920f7954ddf01f97263926e512b0e Mon Sep 17 00:00:00 2001 From: omgteam Date: Mon, 13 Oct 2014 09:59:41 -0700 Subject: [PATCH 38/38] Bug Fix: without unpersist method in RandomForest.scala During trainning Gradient Boosting Decision Tree on large-scale sparse data, spark spill hundreds of data onto disk. And find the bug below: In version 1.1.0 DecisionTree.scala, train Method, treeInput has been persisted in Memory, but without unpersist. It caused heavy DISK usage. In github version(1.2.0 maybe), RandomForest.scala, train Method, baggedInput has been persisted but without unpersisted too. After added unpersist, it works right. https://issues.apache.org/jira/browse/SPARK-3918 Author: omgteam Closes #2775 from omgteam/master and squashes the following commits: 815d543 [omgteam] adjust tab to spaces 1a36f83 [omgteam] Bug: fix without unpersist baggedInput in RandomForest.scala --- .../main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index fa7a26f17c3ca..ebbd8e0257209 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -176,6 +176,8 @@ private class RandomForest ( timer.stop("findBestSplits") } + baggedInput.unpersist() + timer.stop("total") logInfo("Internal timing for DecisionTree:")