From bcecd73fdd4d2ec209259cfd57d3ad1d63f028f2 Mon Sep 17 00:00:00 2001 From: Dariusz Kobylarz Date: Tue, 4 Nov 2014 09:53:43 -0800 Subject: [PATCH 01/68] fixed MLlib Naive-Bayes java example bug the filter tests Double objects by references whereas it should test their values Author: Dariusz Kobylarz Closes #3081 from dkobylarz/master and squashes the following commits: 5d43a39 [Dariusz Kobylarz] naive bayes example update a304b93 [Dariusz Kobylarz] fixed MLlib Naive-Bayes java example bug --- docs/mllib-naive-bayes.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 7f9d4c6563944..d5b044d94fdd7 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -88,11 +88,11 @@ JavaPairRDD predictionAndLabel = return new Tuple2(model.predict(p.features()), p.label()); } }); -double accuracy = 1.0 * predictionAndLabel.filter(new Function, Boolean>() { +double accuracy = predictionAndLabel.filter(new Function, Boolean>() { @Override public Boolean call(Tuple2 pl) { - return pl._1() == pl._2(); + return pl._1().equals(pl._2()); } - }).count() / test.count(); + }).count() / (double) test.count(); {% endhighlight %} From f90ad5d426cb726079c490a9bb4b1100e2b4e602 Mon Sep 17 00:00:00 2001 From: Niklas Wilcke <1wilcke@informatik.uni-hamburg.de> Date: Tue, 4 Nov 2014 09:57:03 -0800 Subject: [PATCH 02/68] [Spark-4060] [MLlib] exposing special rdd functions to the public Author: Niklas Wilcke <1wilcke@informatik.uni-hamburg.de> Closes #2907 from numbnut/master and squashes the following commits: 7f7c767 [Niklas Wilcke] [Spark-4060] [MLlib] exposing special rdd functions to the public, #2907 --- .../spark/mllib/evaluation/AreaUnderCurve.scala | 2 +- .../org/apache/spark/mllib/rdd/RDDFunctions.scala | 11 ++++++----- .../scala/org/apache/spark/mllib/rdd/SlidingRDD.scala | 5 +++-- .../apache/spark/mllib/rdd/RDDFunctionsSuite.scala | 6 +++--- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 7858ec602483f..078fbfbe4f0e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve { */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( - seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points), combOp = _ + _ ) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index b5e403bc8c14d..57c0768084e41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD @@ -28,8 +29,8 @@ import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. */ -private[mllib] -class RDDFunctions[T: ClassTag](self: RDD[T]) { +@DeveloperApi +class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding @@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. */ - def sliding(windowSize: Int): RDD[Seq[T]] = { + def sliding(windowSize: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") if (windowSize == 1) { - self.map(Seq(_)) + self.map(Array(_)) } else { new SlidingRDD[T](self, windowSize) } @@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } } -private[mllib] +@DeveloperApi object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index dd80782c0f001..35e81fcb3de0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] */ private[mllib] class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) - extends RDD[Seq[T]](parent) { + extends RDD[Array[T]](parent) { require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") - override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) .sliding(windowSize) .withPartial(false) + .map(_.toArray) } override def getPreferredLocations(split: Partition): Seq[String] = diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 27a19f793242b..4ef67a40b9f49 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) assert(rdd.partitions.size === data.length) - val sliding = rdd.sliding(3) - val expected = data.flatMap(x => x).sliding(3).toList - assert(sliding.collect().toList === expected) + val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) + val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) + assert(sliding === expected) } test("treeAggregate") { From 5e73138a0152b78380b3f1def4b969b58e70dd11 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 4 Nov 2014 16:15:38 -0800 Subject: [PATCH 03/68] [SPARK-2938] Support SASL authentication in NettyBlockTransferService Also lays the groundwork for supporting it inside the external shuffle service. Author: Aaron Davidson Closes #3087 from aarondav/sasl and squashes the following commits: 3481718 [Aaron Davidson] Delete rogue println 44f8410 [Aaron Davidson] Delete documentation - muahaha! eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level a6b95f1 [Aaron Davidson] Address comments 785bbde [Aaron Davidson] Cleanup 79973cb [Aaron Davidson] Remove unused file 151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination 7b42adb [Aaron Davidson] Add unit tests 8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService --- .../org/apache/spark/SecurityManager.scala | 23 ++- .../scala/org/apache/spark/SparkConf.scala | 6 + .../scala/org/apache/spark/SparkContext.scala | 2 + .../scala/org/apache/spark/SparkEnv.scala | 3 +- .../org/apache/spark/SparkSaslClient.scala | 147 --------------- .../org/apache/spark/SparkSaslServer.scala | 176 ------------------ .../org/apache/spark/executor/Executor.scala | 1 + .../netty/NettyBlockTransferService.scala | 28 ++- .../apache/spark/network/nio/Connection.scala | 5 +- .../spark/network/nio/ConnectionManager.scala | 7 +- .../apache/spark/storage/BlockManager.scala | 45 +++-- .../NettyBlockTransferSecuritySuite.scala | 161 ++++++++++++++++ .../network/nio/ConnectionManagerSuite.scala | 6 +- .../BlockManagerReplicationSuite.scala | 2 + .../spark/storage/BlockManagerSuite.scala | 4 +- docs/security.md | 1 - .../spark/network/TransportContext.java | 15 +- .../spark/network/client/TransportClient.java | 11 +- .../client/TransportClientBootstrap.java | 32 ++++ .../client/TransportClientFactory.java | 64 +++++-- .../spark/network/server/NoOpRpcHandler.java | 2 +- .../spark/network/server/RpcHandler.java | 19 +- .../server/TransportRequestHandler.java | 1 + .../spark/network/util/TransportConf.java | 3 + .../network/sasl/SaslClientBootstrap.java | 74 ++++++++ .../spark/network/sasl/SaslMessage.java | 74 ++++++++ .../spark/network/sasl/SaslRpcHandler.java | 97 ++++++++++ .../spark/network/sasl/SecretKeyHolder.java | 35 ++++ .../spark/network/sasl/SparkSaslClient.java | 138 ++++++++++++++ .../spark/network/sasl/SparkSaslServer.java | 170 +++++++++++++++++ .../shuffle/ExternalShuffleBlockHandler.java | 2 +- .../shuffle/ExternalShuffleClient.java | 15 +- .../spark/network/shuffle/ShuffleClient.java | 11 +- .../network/sasl/SaslIntegrationSuite.java | 172 +++++++++++++++++ .../spark/network/sasl/SparkSaslSuite.java | 89 +++++++++ .../ExternalShuffleIntegrationSuite.java | 7 +- .../streaming/ReceivedBlockHandlerSuite.scala | 1 + 37 files changed, 1257 insertions(+), 392 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/SparkSaslClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/SparkSaslServer.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala create mode 100644 network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0e0f1a7b2377e..dee935ffad51f 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication} import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.sasl.SecretKeyHolder /** * Spark class responsible for security. @@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * Authenticator installed in the SecurityManager to how it does the authentication * and in this case gets the user name and password from the request. * - * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously * exchange messages. For this we use the Java SASL * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 * as the authentication mechanism. This means the shared secret is not passed @@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * of protection they want. If we support those, the messages will also have to * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. * - * Since the connectionManager does asynchronous messages passing, the SASL + * Since the NioBlockTransferService does asynchronous messages passing, the SASL * authentication is a bit more complex. A ConnectionManager can be both a client * and a Server, so for a particular connection is has to determine what to do. * A ConnectionId was added to be able to track connections and is used to @@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil * and waits for the response from the server and does the handshake before sending * the real message. * + * The NettyBlockTransferService ensures that SASL authentication is performed + * synchronously prior to any other communication on a connection. This is done in + * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. + * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work * properly. For non-Yarn deployments, users can write a filter to go through a @@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * can take place. */ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" @@ -337,4 +342,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * @return the secret key as a String if authentication is enabled, otherwise returns null */ def getSecretKey(): String = secretKey + + override def getSaslUser(appId: String): String = { + val myAppId = sparkConf.getAppId + require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") + getSaslUser() + } + + override def getSecretKey(appId: String): String = { + val myAppId = sparkConf.getAppId + require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") + getSecretKey() + } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ad0a9017afead..4c6c86c7bad78 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { */ getAll.filter { case (k, _) => isAkkaConf(k) } + /** + * Returns the Spark application id, valid in the Driver after TaskScheduler registration and + * from the start in the Executor. + */ + def getAppId: String = get("spark.app.id") + /** Does the configuration contain a given parameter? */ def contains(key: String): Boolean = settings.contains(key) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 40444c237b738..3cdaa6a9cc8a8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { val applicationId: String = taskScheduler.applicationId() conf.set("spark.app.id", applicationId) + env.blockManager.initialize(applicationId) + val metricsSystem = env.metricsSystem // The metrics system for Driver need to be set spark.app.id to app ID. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index e2f13accdfab5..45e9d7f243e96 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -276,7 +276,7 @@ object SparkEnv extends Logging { val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { case "netty" => - new NettyBlockTransferService(conf) + new NettyBlockTransferService(conf, securityManager) case "nio" => new NioBlockTransferService(conf, securityManager) } @@ -285,6 +285,7 @@ object SparkEnv extends Logging { "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala deleted file mode 100644 index a954fcc0c31fa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.RealmCallback -import javax.security.sasl.RealmChoiceCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslClient -import javax.security.sasl.SaslException - -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 - -/** - * Implements SASL Client logic for Spark - */ -private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { - - /** - * Used to respond to server's counterpart, SaslServer with SASL tokens - * represented as byte arrays. - * - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslClientCallbackHandler(securityMgr)) - - /** - * Used to initiate SASL handshake with server. - * @return response to challenge if needed - */ - def firstToken(): Array[Byte] = { - synchronized { - val saslToken: Array[Byte] = - if (saslClient != null && saslClient.hasInitialResponse()) { - logDebug("has initial response") - saslClient.evaluateChallenge(new Array[Byte](0)) - } else { - new Array[Byte](0) - } - saslToken - } - } - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslClient != null) saslClient.isComplete() else false - } - } - - /** - * Respond to server's SASL token. - * @param saslTokenMessage contains server's SASL token - * @return client's response SASL token - */ - def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { - synchronized { - if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - def dispose() { - synchronized { - if (saslClient != null) { - try { - saslClient.dispose() - } catch { - case e: SaslException => // ignored - } finally { - saslClient = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends - CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - private val secretKey = securityMgr.getSecretKey() - private val userPassword: Array[Char] = SparkSaslServer.encodePassword( - if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8)) - - /** - * Implementation used to respond to SASL request from the server. - * - * @param callbacks objects that indicate what credential information the - * server's SaslServer requires from the client. - */ - override def handle(callbacks: Array[Callback]) { - logDebug("in the sasl client callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL client callback: setting username: " + userName) - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL client callback: setting userPassword") - pc.setPassword(userPassword) - } - case rc: RealmCallback => { - logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case cb: RealmChoiceCallback => {} - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala deleted file mode 100644 index 7c2afb364661f..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.AuthorizeCallback -import javax.security.sasl.RealmCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslException -import javax.security.sasl.SaslServer -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 -import org.apache.commons.net.util.Base64 - -/** - * Encapsulates SASL server logic - */ -private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { - - /** - * Actual SASL work done by this object from javax.security.sasl. - */ - private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, - SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslDigestCallbackHandler(securityMgr)) - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslServer != null) saslServer.isComplete() else false - } - } - - /** - * Used to respond to server SASL tokens. - * @param token Server's SASL token - * @return response to send back to the server. - */ - def response(token: Array[Byte]): Array[Byte] = { - synchronized { - if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. - */ - def dispose() { - synchronized { - if (saslServer != null) { - try { - saslServer.dispose() - } catch { - case e: SaslException => // ignore - } finally { - saslServer = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * for SASL DIGEST-MD5 mechanism - */ - private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) - extends CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - - override def handle(callbacks: Array[Callback]) { - logDebug("In the sasl server callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL server callback: setting username") - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL server callback: setting userPassword") - val password: Array[Char] = - SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8)) - pc.setPassword(password) - } - case rc: RealmCallback => { - logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case ac: AuthorizeCallback => { - val authid = ac.getAuthenticationID() - val authzid = ac.getAuthorizationID() - if (authid.equals(authzid)) { - logDebug("set auth to true") - ac.setAuthorized(true) - } else { - logDebug("set auth to false") - ac.setAuthorized(false) - } - if (ac.isAuthorized()) { - logDebug("sasl server is authorized") - ac.setAuthorizedID(authzid) - } - } - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") - } - } - } -} - -private[spark] object SparkSaslServer { - - /** - * This is passed as the server name when creating the sasl client/server. - * This could be changed to be configurable in the future. - */ - val SASL_DEFAULT_REALM = "default" - - /** - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - val DIGEST = "DIGEST-MD5" - - /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. - */ - val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") - - /** - * Encode a byte[] identifier as a Base64-encoded string. - * - * @param identifier identifier to encode - * @return Base64-encoded string - */ - def encodeIdentifier(identifier: Array[Byte]): String = { - new String(Base64.encodeBase64(identifier), UTF_8) - } - - /** - * Encode a password as a base64-encoded char[] array. - * @param password as a byte array. - * @return password as a char array. - */ - def encodePassword(password: Array[Byte]): Array[Char] = { - new String(Base64.encodeBase64(password), UTF_8).toCharArray() - } -} - diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 8b095e23f32ff..abc1dd0be6237 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -86,6 +86,7 @@ private[spark] class Executor( conf, executorId, slaveHostname, port, isLocal, actorSystem) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) + _env.blockManager.initialize(conf.getAppId) _env } else { SparkEnv.get diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 1c4327cf13b51..0d1fc81d2a16f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,13 +17,15 @@ package org.apache.spark.network.netty +import scala.collection.JavaConversions._ import scala.concurrent.{Future, Promise} -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} +import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer @@ -33,18 +35,30 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { +class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService { + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. - val serializer = new JavaSerializer(conf) + private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() + private val transportConf = SparkTransportConf.fromSparkConf(conf) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) - transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler) - clientFactory = transportContext.createClientFactory() + val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { + val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + if (!authEnabled) { + (nettyRpcHandler, None) + } else { + (new SaslRpcHandler(nettyRpcHandler, securityManager), + Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) + } + } + transportContext = new TransportContext(transportConf, rpcHandler) + clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() logInfo("Server created on " + server.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 4f6f5e235811d..c2d9578be7ebb 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,12 +23,13 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import org.apache.spark._ - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal +import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} + private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8408b75bb4d65..f198aa8564a54 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} import org.apache.spark.util.Utils import scala.util.Try @@ -600,7 +601,7 @@ private[nio] class ConnectionManager( } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) + replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -634,7 +635,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) } } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) @@ -778,7 +779,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5f5dd0dc1c63f..655d16c65c8b5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -57,6 +57,12 @@ private[spark] class BlockResult( inputMetrics.bytesRead = bytes } +/** + * Manager running on every node (driver and executors) which provides interfaces for putting and + * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). + * + * Note that #initialize() must be called before the BlockManager is usable. + */ private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -69,8 +75,6 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService) extends BlockDataManager with Logging { - blockTransferService.init(this) - val diskBlockManager = new DiskBlockManager(this, conf) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -102,22 +106,16 @@ private[spark] class BlockManager( + " switch to sort-based shuffle.") } - val blockManagerId = BlockManagerId( - executorId, blockTransferService.hostName, blockTransferService.port) + var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external // service, or just our own Executor's BlockManager. - private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) { - BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) - } else { - blockManagerId - } + private[spark] var shuffleServerId: BlockManagerId = _ // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val appId = conf.get("spark.app.id", "unknown-app-id") - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId) + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf)) } else { blockTransferService } @@ -150,8 +148,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - initialize() - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -176,10 +172,27 @@ private[spark] class BlockManager( } /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured. + * Initializes the BlockManager with the given appId. This is not performed in the constructor as + * the appId may not be known at BlockManager instantiation time (in particular for the driver, + * where it is only learned after registration with the TaskScheduler). + * + * This method initializes the BlockTransferService and ShuffleClient, registers with the + * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * service if configured. */ - private def initialize(): Unit = { + def initialize(appId: String): Unit = { + blockTransferService.init(this) + shuffleClient.init(appId) + + blockManagerId = BlockManagerId( + executorId, blockTransferService.hostName, blockTransferService.port) + + shuffleServerId = if (externalShuffleServiceEnabled) { + BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + } else { + blockManagerId + } + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) // Register Executors' configuration with the local shuffle service, if one should exist. diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala new file mode 100644 index 0000000000000..bed0ed9d713dd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio._ +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.util.{Failure, Success, Try} + +import org.apache.commons.io.IOUtils +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.storage.{BlockId, ShuffleBlockId} +import org.apache.spark.{SecurityManager, SparkConf} +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers} + +class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { + test("security default off") { + testConnection(new SparkConf, new SparkConf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + test("security on same password") { + val conf = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + testConnection(conf, conf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + test("security on mismatch password") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate.secret", "bad") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Mismatched response") + } + } + + test("security mismatch auth off on server") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "false") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => // any funny error may occur, sever will interpret SASL token as RPC + } + } + + test("security mismatch auth off on client") { + val conf0 = new SparkConf() + .set("spark.authenticate", "false") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "true") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Expected SaslMessage") + } + } + + test("security mismatch app ids") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.app.id", "other-id") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("SASL appId app-id did not match") + } + } + + /** + * Creates two servers with different configurations and sees if they can talk. + * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed + * properly. We will throw an out-of-band exception if something other than that goes wrong. + */ + private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = { + val blockManager = mock[BlockDataManager] + val blockId = ShuffleBlockId(0, 1, 2) + val blockString = "Hello, world!" + val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes)) + when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) + + val securityManager0 = new SecurityManager(conf0) + val exec0 = new NettyBlockTransferService(conf0, securityManager0) + exec0.init(blockManager) + + val securityManager1 = new SecurityManager(conf1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1) + exec1.init(blockManager) + + val result = fetchBlock(exec0, exec1, "1", blockId) match { + case Success(buf) => + IOUtils.toString(buf.createInputStream()) should equal(blockString) + buf.release() + Success() + case Failure(t) => + Failure(t) + } + exec0.close() + exec1.close() + result + } + + /** Synchronously fetches a single block, acting as the given executor fetching from another. */ + private def fetchBlock( + self: BlockTransferService, + from: BlockTransferService, + execId: String, + blockId: BlockId): Try[ManagedBuffer] = { + + val promise = Promise[ManagedBuffer]() + + self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + promise.failure(exception) + } + + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + promise.success(data.retain()) + } + }) + + Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + promise.future.value.get + } +} + diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index b70734dfe37cf..716f875d30b8a 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") + conf.set("spark.app.id", "app-id") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) var numReceivedMessages = 0 @@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite { test("security mismatch password") { val conf = new SparkConf conf.set("spark.authenticate", "true") + conf.set("spark.app.id", "app-id") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) @@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite { None }) - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") + val badconf = conf.clone.set("spark.authenticate.secret", "bad") val badsecurityManager = new SecurityManager(badconf) val managerServer = new ConnectionManager(0, badconf, badsecurityManager) var numReceivedServerMessages = 0 diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c6d7105592096..1461fa69db90d 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -63,6 +63,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val transfer = new NioBlockTransferService(conf, securityMgr) val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) + store.initialize("app-id") allStores += store store } @@ -263,6 +264,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer) + failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 715b740b857b2..0782876c8e3c6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) + manager.initialize("app-id") + manager } before { diff --git a/docs/security.md b/docs/security.md index ec0523184d665..1e206a139fb72 100644 --- a/docs/security.md +++ b/docs/security.md @@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.* ## Web UI diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index a271841e4e56c..5bc6e5a2418a9 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,12 +17,16 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.MessageDecoder; @@ -64,8 +68,17 @@ public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this.decoder = new MessageDecoder(); } + /** + * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning + * a new Client. Bootstraps will be executed synchronously, and must run successfully in order + * to create a Client. + */ + public TransportClientFactory createClientFactory(List bootstraps) { + return new TransportClientFactory(this, bootstraps); + } + public TransportClientFactory createClientFactory() { - return new TransportClientFactory(this); + return createClientFactory(Lists.newArrayList()); } /** Create a server which will attempt to bind to a specific port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 01c143fff423c..a08cee02dd576 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,10 +19,9 @@ import java.io.Closeable; import java.util.UUID; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; @@ -186,4 +185,12 @@ public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("remoteAdress", channel.remoteAddress()) + .add("isActive", isActive()) + .toString(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java new file mode 100644 index 0000000000000..65e8020e34121 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * A bootstrap which is executed on a TransportClient before it is returned to the user. + * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- + * connection basis. + * + * Since connections (and TransportClients) are reused as much as possible, it is generally + * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with + * the JVM itself. + */ +public interface TransportClientBootstrap { + /** Performs the bootstrapping operation, throwing an exception on failure. */ + public void doBootstrap(TransportClient client) throws RuntimeException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 0b4a1d8286407..1723fed307257 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -21,10 +21,14 @@ import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; @@ -40,6 +44,7 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -47,22 +52,29 @@ * Factory for creating {@link TransportClient}s by using createClient. * * The factory maintains a connection pool to other hosts and should return the same - * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for - * all {@link TransportClient}s. + * TransportClient for the same remote host. It also shares a single worker thread pool for + * all TransportClients. + * + * TransportClients will be reused whenever possible. Prior to completing the creation of a new + * TransportClient, all given {@link TransportClientBootstrap}s will be run. */ public class TransportClientFactory implements Closeable { private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); private final TransportContext context; private final TransportConf conf; + private final List clientBootstraps; private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; private EventLoopGroup workerGroup; - public TransportClientFactory(TransportContext context) { - this.context = context; + public TransportClientFactory( + TransportContext context, + List clientBootstraps) { + this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); + this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); @@ -72,9 +84,12 @@ public TransportClientFactory(TransportContext context) { } /** - * Create a new BlockFetchingClient connecting to the given remote host / port. + * Create a new {@link TransportClient} connecting to the given remote host / port. This will + * reuse TransportClients if they are still active and are for the same remote address. Prior + * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s + * that are registered with this factory. * - * This blocks until a connection is successfully established. + * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. */ @@ -104,17 +119,18 @@ public TransportClient createClient(String remoteHost, int remotePort) { // Use pooled buffers to reduce temporary buffer allocation bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); - final AtomicReference client = new AtomicReference(); + final AtomicReference clientRef = new AtomicReference(); bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); - client.set(clientHandler.getClient()); + clientRef.set(clientHandler.getClient()); } }); // Connect to the remote server + long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new RuntimeException( @@ -123,15 +139,35 @@ public void initChannel(SocketChannel ch) { throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); } - // Successful connection -- in the event that two threads raced to create a client, we will + TransportClient client = clientRef.get(); + assert client != null : "Channel future completed successfully with null client"; + + // Execute any client bootstraps synchronously before marking the Client as successful. + long preBootstrap = System.currentTimeMillis(); + logger.debug("Connection to {} successful, running bootstraps...", address); + try { + for (TransportClientBootstrap clientBootstrap : clientBootstraps) { + clientBootstrap.doBootstrap(client); + } + } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala + long bootstrapTime = System.currentTimeMillis() - preBootstrap; + logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); + client.close(); + throw Throwables.propagate(e); + } + long postBootstrap = System.currentTimeMillis(); + + // Successful connection & bootstrap -- in the event that two threads raced to create a client, // use the first one that was put into the connectionPool and close the one we made here. - assert client.get() != null : "Channel future completed successfully with null client"; - TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + TransportClient oldClient = connectionPool.putIfAbsent(address, client); if (oldClient == null) { - return client.get(); + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, postBootstrap - preConnect, postBootstrap - preBootstrap); + return client; } else { - logger.debug("Two clients were created concurrently, second one will be disposed."); - client.get().close(); + logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", + postBootstrap - preConnect); + client.close(); return oldClient; } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 5a3f003726fc1..1502b7489e864 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -21,7 +21,7 @@ import org.apache.spark.network.client.TransportClient; /** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ -public class NoOpRpcHandler implements RpcHandler { +public class NoOpRpcHandler extends RpcHandler { private final StreamManager streamManager; public NoOpRpcHandler() { diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 2369dc6203944..2ba92a40f8b0a 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,22 +23,33 @@ /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ -public interface RpcHandler { +public abstract class RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * + * This method will not be called in parallel for a single TransportClient (i.e., channel). + * * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. + * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. */ - void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + public abstract void receive( + TransportClient client, + byte[] message, + RpcResponseCallback callback); /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. */ - StreamManager getStreamManager(); + public abstract StreamManager getStreamManager(); + + /** + * Invoked when the connection associated with the given client has been invalidated. + * No further requests will come from this client. + */ + public void connectionTerminated(TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 17fe9001b35cc..1580180cc17e9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -86,6 +86,7 @@ public void channelUnregistered() { for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); } + rpcHandler.connectionTerminated(reverseClient); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index a68f38e0e94c9..823790dd3c66f 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -55,4 +55,7 @@ public int connectionTimeoutMs() { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ + public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java new file mode 100644 index 0000000000000..7bc91e375371f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final SecretKeyHolder secretKeyHolder; + + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.appId = appId; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. + */ + @Override + public void doBootstrap(TransportClient client) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(appId, payload); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout()); + payload = saslClient.response(response); + } + } finally { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 0000000000000..5b77e18c26bf4 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. This appId allows a single SaslRpcHandler to multiplex different + * applications which may be using different sets of credentials. + */ +class SaslMessage implements Encodable { + + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + public final byte[] payload; + + public SaslMessage(String appId, byte[] payload) { + this.appId = appId; + this.payload = payload; + } + + @Override + public int encodedLength() { + // tag + appIdLength + appId + payloadLength + payload + return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + byte[] idBytes = appId.getBytes(Charsets.UTF_8); + buf.writeInt(idBytes.length); + buf.writeBytes(idBytes); + buf.writeInt(payload.length); + buf.writeBytes(payload); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else"); + } + + int idLength = buf.readInt(); + byte[] idBytes = new byte[idLength]; + buf.readBytes(idBytes); + + int payloadLength = buf.readInt(); + byte[] payload = new byte[payloadLength]; + buf.readBytes(payload); + + return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java new file mode 100644 index 0000000000000..3777a18e33f78 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.concurrent.ConcurrentMap; + +import com.google.common.base.Charsets; +import com.google.common.collect.Maps; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; + +/** + * RPC Handler which performs SASL authentication before delegating to a child RPC handler. + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. + */ +public class SaslRpcHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** RpcHandler we will delegate to for authenticated connections. */ + private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Maps each channel to its SASL authentication state. */ + private final ConcurrentMap channelAuthenticationMap; + + public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + this.channelAuthenticationMap = Maps.newConcurrentMap(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + SparkSaslServer saslServer = channelAuthenticationMap.get(client); + if (saslServer != null && saslServer.isComplete()) { + // Authentication complete, delegate to base handler. + delegate.receive(client, message, callback); + return; + } + + SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); + channelAuthenticationMap.put(client, saslServer); + } + + byte[] response = saslServer.response(saslMessage.payload); + if (saslServer.isComplete()) { + logger.debug("SASL authentication successful for channel {}", client); + } + callback.onSuccess(response); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void connectionTerminated(TransportClient client) { + SparkSaslServer saslServer = channelAuthenticationMap.remove(client); + if (saslServer != null) { + saslServer.dispose(); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java new file mode 100644 index 0000000000000..81d5766794688 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +/** + * Interface for getting a secret key associated with some application. + */ +public interface SecretKeyHolder { + /** + * Gets an appropriate SASL User for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL user. + */ + String getSaslUser(String appId); + + /** + * Gets an appropriate SASL secret key for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key. + */ + String getSecretKey(String appId); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java new file mode 100644 index 0000000000000..72ba737b998bc --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; + +import com.google.common.base.Throwables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.sasl.SparkSaslServer.*; + +/** + * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. This client initializes the protocol via a + * firstToken, which is then followed by a set of challenges and responses. + */ +public class SparkSaslClient { + private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); + + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslClient saslClient; + + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, + SASL_PROPS, new ClientCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** Used to initiate SASL handshake with server. */ + public synchronized byte[] firstToken() { + if (saslClient != null && saslClient.hasInitialResponse()) { + try { + return saslClient.evaluateChallenge(new byte[0]); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } else { + return new byte[0]; + } + } + + /** Determines whether the authentication exchange has completed. */ + public synchronized boolean isComplete() { + return saslClient != null && saslClient.isComplete(); + } + + /** + * Respond to server's SASL token. + * @param token contains server's SASL token + * @return client's response SASL token + */ + public synchronized byte[] response(byte[] token) { + try { + return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + public synchronized void dispose() { + if (saslClient != null) { + try { + saslClient.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslClient = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class ClientCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL client callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL client callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL client callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + logger.info("Realm callback"); + } else if (callback instanceof RealmChoiceCallback) { + // ignore (?) + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java new file mode 100644 index 0000000000000..2c0ce40c75e80 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.BaseEncoding; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. (It is not a server in the sense of accepting + * connections on some socket.) + */ +public class SparkSaslServer { + private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + static final String DEFAULT_REALM = "default"; + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + static final String DIGEST = "DIGEST-MD5"; + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + static final Map SASL_PROPS = ImmutableMap.builder() + .put(Sasl.QOP, "auth") + .put(Sasl.SERVER_AUTH, "true") + .build(); + + /** Identifier for a certain secret key within the secretKeyHolder. */ + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslServer saslServer; + + public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + new DigestCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Determines whether the authentication exchange has completed successfully. + */ + public synchronized boolean isComplete() { + return saslServer != null && saslServer.isComplete(); + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + public synchronized byte[] response(byte[] token) { + try { + return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + public synchronized void dispose() { + if (saslServer != null) { + try { + saslServer.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslServer = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. + */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); + return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); + return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index a9dff31decc83..cd3fea85b19a4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -41,7 +41,7 @@ * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- * level shuffle block. */ -public class ExternalShuffleBlockHandler implements RpcHandler { +public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); private final ExternalShuffleBlockManager blockManager; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6bbabc44b958b..b0b19ba67bddc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,8 +17,6 @@ package org.apache.spark.network.shuffle; -import java.io.Closeable; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,15 +34,20 @@ * BlockTransferService), which has the downside of losing the shuffle data if we lose the * executors. */ -public class ExternalShuffleClient implements ShuffleClient { +public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportClientFactory clientFactory; - private final String appId; - public ExternalShuffleClient(TransportConf conf, String appId) { + private String appId; + + public ExternalShuffleClient(TransportConf conf) { TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); this.clientFactory = context.createClientFactory(); + } + + @Override + public void init(String appId) { this.appId = appId; } @@ -55,6 +58,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener) { + assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { TransportClient client = clientFactory.createClient(host, port); @@ -82,6 +86,7 @@ public void registerWithShuffleServer( int port, String execId, ExecutorShuffleInfo executorInfo) { + assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index d46a562394557..f72ab40690d0d 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -20,7 +20,14 @@ import java.io.Closeable; /** Provides an interface for reading shuffle files, either from an Executor or external service. */ -public interface ShuffleClient extends Closeable { +public abstract class ShuffleClient implements Closeable { + + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. + */ + public void init(String appId) { } + /** * Fetch a sequence of blocks from a remote node asynchronously, * @@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable { * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - public void fetchBlocks( + public abstract void fetchBlocks( String host, int port, String execId, diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 0000000000000..84781207861ed --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + static TransportContext context; + + TransportClientFactory clientFactory; + + /** Provides a secret key holder which always returns the given secret key. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + + private final String secretKey; + + TestSecretKeyHolder(String secretKey) { + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + @Override + public String getSecretKey(String appId) { + return secretKey; + } + } + + + @BeforeClass + public static void beforeAll() throws IOException { + SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); + SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); + conf = new TransportConf(new SystemPropertyConfigProvider()); + context = new TransportContext(conf, handler); + server = context.createServer(); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + } + + @Test + public void testBadClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + + try { + // Bootstrap should fail on startup. + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(new byte[13], 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** RPC handler which simply responds with the message it received. */ + public static class TestRpcHandler extends RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 0000000000000..67a07f38eb5a0 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. + */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b3bcf5fd68e73..bc101f53844d5 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,8 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @Override @@ -164,6 +165,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); } + client.close(); return res; } @@ -265,7 +267,8 @@ public void testFetchNoServer() throws Exception { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index ad1a6f01b3a57..0f27f55fec4f3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -74,6 +74,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr)) + blockManager.initialize("app-id") tempDirectory = Files.createTempDir() manualClock.setTime(0) From 515abb9afa2d6b58947af6bb079a493b49d315ca Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 4 Nov 2014 18:14:28 -0800 Subject: [PATCH 04/68] [SQL] Add String option for DSL AS Author: Michael Armbrust Closes #3097 from marmbrus/asString and squashes the following commits: 6430520 [Michael Armbrust] Add String option for DSL AS --- .../main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3314e15477016..31dc5a58e68e5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -110,7 +110,8 @@ package object dsl { def asc = SortOrder(expr, Ascending) def desc = SortOrder(expr, Descending) - def as(s: Symbol) = Alias(expr, s.name)() + def as(alias: String) = Alias(expr, alias)() + def as(alias: Symbol) = Alias(expr, alias.name)() } trait ExpressionConversions { From c8abddc5164d8cf11cdede6ab3d5d1ea08028708 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 4 Nov 2014 21:35:52 -0800 Subject: [PATCH 05/68] [SPARK-3964] [MLlib] [PySpark] add Hypothesis test Python API ``` pyspark.mllib.stat.StatisticschiSqTest(observed, expected=None) :: Experimental :: If `observed` is Vector, conduct Pearson's chi-squared goodness of fit test of the observed data against the expected distribution, or againt the uniform distribution (by default), with each category having an expected frequency of `1 / len(observed)`. (Note: `observed` cannot contain negative values) If `observed` is matrix, conduct Pearson's independence test on the input contingency matrix, which cannot contain negative entries or columns or rows that sum up to 0. If `observed` is an RDD of LabeledPoint, conduct Pearson's independence test for every feature against the label across the input RDD. For each feature, the (feature, label) pairs are converted into a contingency matrix for which the chi-squared statistic is computed. All label and feature values must be categorical. :param observed: it could be a vector containing the observed categorical counts/relative frequencies, or the contingency matrix (containing either counts or relative frequencies), or an RDD of LabeledPoint containing the labeled dataset with categorical features. Real-valued features will be treated as categorical for each distinct value. :param expected: Vector containing the expected categorical counts/relative frequencies. `expected` is rescaled if the `expected` sum differs from the `observed` sum. :return: ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, the method used, and the null hypothesis. ``` Author: Davies Liu Closes #3091 from davies/his and squashes the following commits: 145d16c [Davies Liu] address comments 0ab0764 [Davies Liu] fix float 5097d54 [Davies Liu] add Hypothesis test Python API --- docs/mllib-statistics.md | 40 +++++ .../mllib/api/python/PythonMLLibAPI.scala | 26 ++++ python/pyspark/mllib/common.py | 7 +- python/pyspark/mllib/linalg.py | 13 +- python/pyspark/mllib/stat.py | 137 +++++++++++++++++- 5 files changed, 219 insertions(+), 4 deletions(-) diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 10a5131c07414..ca8c29218f52d 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -380,6 +380,46 @@ for (ChiSqTestResult result : featureTestResults) { {% endhighlight %} +
+[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics$) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +hypothesis tests. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors, Matrices +from pyspark.mllib.regresssion import LabeledPoint +from pyspark.mllib.stat import Statistics + +sc = SparkContext() + +vec = Vectors.dense(...) # a vector composed of the frequencies of events + +# compute the goodness of fit. If a second vector to test against is not supplied as a parameter, +# the test runs against a uniform distribution. +goodnessOfFitTestResult = Statistics.chiSqTest(vec) +print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. + +mat = Matrices.dense(...) # a contingency matrix + +# conduct Pearson's independence test on the input contingency matrix +independenceTestResult = Statistics.chiSqTest(mat) +print independenceTestResult # summary of the test including the p-value, degrees of freedom... + +obs = sc.parallelize(...) # LabeledPoint(feature, label) . + +# The contingency table is constructed from an RDD of LabeledPoint and used to conduct +# the independence test. Returns an array containing the ChiSquaredTestResult for every feature +# against the label. +featureTestResults = Statistics.chiSqTest(obs) + +for i, result in enumerate(featureTestResults): + print "Column $d:" % (i + 1) + print result +{% endhighlight %} +
+ ## Random data generation diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 65b98a8ceea55..d832ae34b55e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -43,6 +43,7 @@ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -454,6 +455,31 @@ class PythonMLLibAPI extends Serializable { Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method)) } + /** + * Java stub for mllib Statistics.chiSqTest() + */ + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + if (expected == null) { + Statistics.chiSqTest(observed) + } else { + Statistics.chiSqTest(observed, expected) + } + } + + /** + * Java stub for mllib Statistics.chiSqTest(observed: Matrix) + */ + def chiSqTest(observed: Matrix): ChiSqTestResult = { + Statistics.chiSqTest(observed) + } + + /** + * Java stub for mllib Statistics.chiSqTest(RDD[LabelPoint]) + */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = { + Statistics.chiSqTest(data.rdd) + } + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark private def getCorrNameOrDefault(method: String) = { if (method == null) CorrelationNames.defaultCorrName else method diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index dbe5f698b7345..c6149fe391ec8 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -98,8 +98,13 @@ def _java2py(sc, r): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) - elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes: + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) + elif isinstance(r, (JavaArray, JavaList)): + try: + r = sc._jvm.SerDe.dumps(r) + except Py4JJavaError: + pass # not pickable if isinstance(r, bytearray): r = PickleSerializer().loads(str(r)) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index c0c3dff31e7f8..e35202dca0acc 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -33,7 +33,7 @@ IntegerType, ByteType, Row -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): @@ -578,6 +578,8 @@ class DenseMatrix(Matrix): def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) assert len(values) == numRows * numCols + if not isinstance(values, array.array): + values = array.array('d', values) self.values = values def __reduce__(self): @@ -596,6 +598,15 @@ def toArray(self): return np.reshape(self.values, (self.numRows, self.numCols), order='F') +class Matrices(object): + @staticmethod + def dense(numRows, numCols, values): + """ + Create a DenseMatrix + """ + return DenseMatrix(numRows, numCols, values) + + def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 15f0652f833d7..0700f8a8e5a8e 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -19,11 +19,12 @@ Python package for statistical functions in MLlib. """ +from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.mllib.linalg import Matrix, _convert_to_vector -__all__ = ['MultivariateStatisticalSummary', 'Statistics'] +__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] class MultivariateStatisticalSummary(JavaModelWrapper): @@ -51,6 +52,54 @@ def min(self): return self.call("min").toArray() +class ChiSqTestResult(JavaModelWrapper): + """ + :: Experimental :: + + Object containing the test results for the chi-squared hypothesis test. + """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() + + class Statistics(object): @staticmethod @@ -135,6 +184,90 @@ def corr(x, y=None, method=None): else: return callMLlibFunc("corr", x.map(float), y.map(float), method) + @staticmethod + def chiSqTest(observed, expected=None): + """ + :: Experimental :: + + If `observed` is Vector, conduct Pearson's chi-squared goodness + of fit test of the observed data against the expected distribution, + or againt the uniform distribution (by default), with each category + having an expected frequency of `1 / len(observed)`. + (Note: `observed` cannot contain negative values) + + If `observed` is matrix, conduct Pearson's independence test on the + input contingency matrix, which cannot contain negative entries or + columns or rows that sum up to 0. + + If `observed` is an RDD of LabeledPoint, conduct Pearson's independence + test for every feature against the label across the input RDD. + For each feature, the (feature, label) pairs are converted into a + contingency matrix for which the chi-squared statistic is computed. + All label and feature values must be categorical. + + :param observed: it could be a vector containing the observed categorical + counts/relative frequencies, or the contingency matrix + (containing either counts or relative frequencies), + or an RDD of LabeledPoint containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + :param expected: Vector containing the expected categorical counts/relative + frequencies. `expected` is rescaled if the `expected` sum + differs from the `observed` sum. + :return: ChiSquaredTest object containing the test statistic, degrees + of freedom, p-value, the method used, and the null hypothesis. + + >>> from pyspark.mllib.linalg import Vectors, Matrices + >>> observed = Vectors.dense([4, 6, 5]) + >>> pearson = Statistics.chiSqTest(observed) + >>> print pearson.statistic + 0.4 + >>> pearson.degreesOfFreedom + 2 + >>> print round(pearson.pValue, 4) + 0.8187 + >>> pearson.method + u'pearson' + >>> pearson.nullHypothesis + u'observed follows the same distribution as expected.' + + >>> observed = Vectors.dense([21, 38, 43, 80]) + >>> expected = Vectors.dense([3, 5, 7, 20]) + >>> pearson = Statistics.chiSqTest(observed, expected) + >>> print round(pearson.pValue, 4) + 0.0027 + + >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + >>> print round(chi.statistic, 4) + 21.9958 + + >>> from pyspark.mllib.regression import LabeledPoint + >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] + >>> rdd = sc.parallelize(data, 4) + >>> chi = Statistics.chiSqTest(rdd) + >>> print chi[0].statistic + 0.75 + >>> print chi[1].statistic + 1.5 + """ + if isinstance(observed, RDD): + jmodels = callMLlibFunc("chiSqTest", observed) + return [ChiSqTestResult(m) for m in jmodels] + + if isinstance(observed, Matrix): + jmodel = callMLlibFunc("chiSqTest", observed) + else: + if expected and len(expected) != len(observed): + raise ValueError("`expected` should have same length with `observed`") + jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected) + return ChiSqTestResult(jmodel) + def _test(): import doctest From 5f13759d3642ea5b58c12a756e7125ac19aff10e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 5 Nov 2014 01:21:53 -0800 Subject: [PATCH 06/68] [SPARK-4029][Streaming] Update streaming driver to reliably save and recover received block metadata on driver failures As part of the initiative of preventing data loss on driver failure, this JIRA tracks the sub task of modifying the streaming driver to reliably save received block metadata, and recover them on driver restart. This was solved by introducing a `ReceivedBlockTracker` that takes all the responsibility of managing the metadata of received blocks (i.e. `ReceivedBlockInfo`, and any actions on them (e.g, allocating blocks to batches, etc.). All actions to block info get written out to a write ahead log (using `WriteAheadLogManager`). On recovery, all the actions are replaying to recreate the pre-failure state of the `ReceivedBlockTracker`, which include the batch-to-block allocations and the unallocated blocks. Furthermore, the `ReceiverInputDStream` was modified to create `WriteAheadLogBackedBlockRDD`s when file segment info is present in the `ReceivedBlockInfo`. After recovery of all the block info (through recovery `ReceivedBlockTracker`), the `WriteAheadLogBackedBlockRDD`s gets recreated with the recovered info, and jobs submitted. The data of the blocks gets pulled from the write ahead logs, thanks to the segment info present in the `ReceivedBlockInfo`. This is still a WIP. Things that are missing here are. - *End-to-end integration tests:* Unit tests that tests the driver recovery, by killing and restarting the streaming context, and verifying all the input data gets processed. This has been implemented but not included in this PR yet. A sneak peek of that DriverFailureSuite can be found in this PR (on my personal repo): https://github.com/tdas/spark/pull/25 I can either include it in this PR, or submit that as a separate PR after this gets in. - *WAL cleanup:* Cleaning up the received data write ahead log, by calling `ReceivedBlockHandler.cleanupOldBlocks`. This is being worked on. Author: Tathagata Das Closes #3026 from tdas/driver-ha-rbt and squashes the following commits: a8009ed [Tathagata Das] Added comment 1d704bb [Tathagata Das] Enabled storing recovered WAL-backed blocks to BM 2ee2484 [Tathagata Das] More minor changes based on PR 47fc1e3 [Tathagata Das] Addressed PR comments. 9a7e3e4 [Tathagata Das] Refactored ReceivedBlockTracker API a bit to make things a little cleaner for users of the tracker. af63655 [Tathagata Das] Minor changes. fce2b21 [Tathagata Das] Removed commented lines 59496d3 [Tathagata Das] Changed class names, made allocation more explicit and added cleanup 19aec7d [Tathagata Das] Fixed casting bug. f66d277 [Tathagata Das] Fix line lengths. cda62ee [Tathagata Das] Added license 25611d6 [Tathagata Das] Minor changes before submitting PR 7ae0a7fb [Tathagata Das] Transferred changes from driver-ha-working branch --- .../dstream/ReceiverInputDStream.scala | 69 +++-- .../rdd/WriteAheadLogBackedBlockRDD.scala | 3 +- .../streaming/scheduler/JobGenerator.scala | 21 +- .../scheduler/ReceivedBlockTracker.scala | 230 +++++++++++++++++ .../streaming/scheduler/ReceiverTracker.scala | 98 ++++--- .../streaming/BasicOperationsSuite.scala | 19 +- .../streaming/ReceivedBlockTrackerSuite.scala | 242 ++++++++++++++++++ .../WriteAheadLogBackedBlockRDDSuite.scala | 4 +- 8 files changed, 597 insertions(+), 89 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index bb47d373de63d..3e67161363e50 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -17,15 +17,14 @@ package org.apache.spark.streaming.dstream -import scala.collection.mutable.HashMap import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver.{WriteAheadLogBasedStoreResult, BlockManagerBasedStoreResult, Receiver} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo -import org.apache.spark.SparkException /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -40,9 +39,6 @@ import org.apache.spark.SparkException abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** Keeps all received blocks information */ - private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]] - /** This is an unique identifier for the network input stream. */ val id = ssc.getNewReceiverStreamId() @@ -58,24 +54,45 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont def stop() {} - /** Ask ReceiverInputTracker for received data blocks and generates RDDs with them. */ + /** + * Generates RDDs with blocks received by the receiver of this stream. */ override def compute(validTime: Time): Option[RDD[T]] = { - // If this is called for any time before the start time of the context, - // then this returns an empty RDD. This may happen when recovering from a - // master failure - if (validTime >= graph.startTime) { - val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id) - receivedBlockInfo(validTime) = blockInfo - val blockIds = blockInfo.map { _.blockStoreResult.blockId.asInstanceOf[BlockId] } - Some(new BlockRDD[T](ssc.sc, blockIds)) - } else { - Some(new BlockRDD[T](ssc.sc, Array.empty)) - } - } + val blockRDD = { - /** Get information on received blocks. */ - private[streaming] def getReceivedBlockInfo(time: Time) = { - receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo]) + if (validTime < graph.startTime) { + // If this is called for any time before the start time of the context, + // then this returns an empty RDD. This may happen when recovering from a + // driver failure without any write ahead log to recover pre-failure data. + new BlockRDD[T](ssc.sc, Array.empty) + } else { + // Otherwise, ask the tracker for all the blocks that have been allocated to this stream + // for this batch + val blockInfos = + ssc.scheduler.receiverTracker.getBlocksOfBatch(validTime).get(id).getOrElse(Seq.empty) + val blockStoreResults = blockInfos.map { _.blockStoreResult } + val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray + + // Check whether all the results are of the same type + val resultTypes = blockStoreResults.map { _.getClass }.distinct + if (resultTypes.size > 1) { + logWarning("Multiple result types in block information, WAL information will be ignored.") + } + + // If all the results are of type WriteAheadLogBasedStoreResult, then create + // WriteAheadLogBackedBlockRDD else create simple BlockRDD. + if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) { + val logSegments = blockStoreResults.map { + _.asInstanceOf[WriteAheadLogBasedStoreResult].segment + }.toArray + // Since storeInBlockManager = false, the storage level does not matter. + new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, + blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER) + } else { + new BlockRDD[T](ssc.sc, blockIds) + } + } + } + Some(blockRDD) } /** @@ -86,10 +103,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont */ private[streaming] override def clearMetadata(time: Time) { super.clearMetadata(time) - val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration)) - receivedBlockInfo --= oldReceivedBlocks.keys - logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", ")) + ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 23295bf658712..dd1e96334952f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -48,7 +48,6 @@ class WriteAheadLogBackedBlockRDDPartition( * If it does not find them, it looks up the corresponding file segment. * * @param sc SparkContext - * @param hadoopConfig Hadoop configuration * @param blockIds Ids of the blocks that contains this RDD's data * @param segments Segments in write ahead logs that contain this RDD's data * @param storeInBlockManager Whether to store in the block manager after reading from the segment @@ -58,7 +57,6 @@ class WriteAheadLogBackedBlockRDDPartition( private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, - @transient hadoopConfig: Configuration, @transient blockIds: Array[BlockId], @transient segments: Array[WriteAheadLogFileSegment], storeInBlockManager: Boolean, @@ -71,6 +69,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( s"the same as number of segments (${segments.length}})!") // Hadoop configuration is not serializable, so broadcast it as a serializable. + @transient private val hadoopConfig = sc.hadoopConfiguration private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) override def getPartitions: Array[Partition] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 7d73ada12d107..39b66e1130768 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -112,7 +112,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Wait until all the received blocks in the network input tracker has // been consumed by network input DStreams, and jobs have been generated with them logInfo("Waiting for all received blocks to be consumed for job generation") - while(!hasTimedOut && jobScheduler.receiverTracker.hasMoreReceivedBlockIds) { + while(!hasTimedOut && jobScheduler.receiverTracker.hasUnallocatedBlocks) { Thread.sleep(pollTime) } logInfo("Waited for all received blocks to be consumed for job generation") @@ -217,14 +217,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { - Try(graph.generateJobs(time)) match { + // Set the SparkEnv in this thread, so that job generation code can access the environment + // Example: BlockRDDs are created in this thread, and it needs to access BlockManager + // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. + SparkEnv.set(ssc.env) + Try { + jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch + graph.generateJobs(time) // generate jobs using allocated block + } match { case Success(jobs) => - val receivedBlockInfo = graph.getReceiverInputStreams.map { stream => - val streamId = stream.id - val receivedBlockInfo = stream.getReceivedBlockInfo(time) - (streamId, receivedBlockInfo) - }.toMap - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo)) + val receivedBlockInfos = + jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray } + jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } @@ -234,6 +238,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) + jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala new file mode 100644 index 0000000000000..5f5e1909908d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager} +import org.apache.spark.util.Utils + +/** Trait representing any event in the ReceivedBlockTracker that updates its state. */ +private[streaming] sealed trait ReceivedBlockTrackerLogEvent + +private[streaming] case class BlockAdditionEvent(receivedBlockInfo: ReceivedBlockInfo) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: AllocatedBlocks) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchCleanupEvent(times: Seq[Time]) + extends ReceivedBlockTrackerLogEvent + + +/** Class representing the blocks of all the streams allocated to a batch */ +private[streaming] +case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { + def getBlocksOfStream(streamId: Int): Seq[ReceivedBlockInfo] = { + streamIdToAllocatedBlocks.get(streamId).getOrElse(Seq.empty) + } +} + +/** + * Class that keep track of all the received blocks, and allocate them to batches + * when required. All actions taken by this class can be saved to a write ahead log + * (if a checkpoint directory has been provided), so that the state of the tracker + * (received blocks and block-to-batch allocations) can be recovered after driver failure. + * + * Note that when any instance of this class is created with a checkpoint directory, + * it will try reading events from logs in the directory. + */ +private[streaming] class ReceivedBlockTracker( + conf: SparkConf, + hadoopConf: Configuration, + streamIds: Seq[Int], + clock: Clock, + checkpointDirOption: Option[String]) + extends Logging { + + private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] + + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] + private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] + + private val logManagerRollingIntervalSecs = conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) + private val logManagerOption = checkpointDirOption.map { checkpointDir => + new WriteAheadLogManager( + ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir), + hadoopConf, + rollingIntervalSecs = logManagerRollingIntervalSecs, + callerName = "ReceivedBlockHandlerMaster", + clock = clock + ) + } + + private var lastAllocatedBatchTime: Time = null + + // Recover block information from write ahead logs + recoverFromWriteAheadLogs() + + /** Add received block. This event will get written to the write ahead log (if enabled). */ + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + try { + writeToLog(BlockAdditionEvent(receivedBlockInfo)) + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + true + } catch { + case e: Exception => + logError(s"Error adding block $receivedBlockInfo", e) + false + } + } + + /** + * Allocate all unallocated blocks to the given batch. + * This event will get written to the write ahead log (if enabled). + */ + def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { + if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { + val streamIdToBlocks = streamIds.map { streamId => + (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + }.toMap + val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) + writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + timeToAllocatedBlocks(batchTime) = allocatedBlocks + lastAllocatedBatchTime = batchTime + allocatedBlocks + } else { + throw new SparkException(s"Unexpected allocation of blocks, " + + s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + } + } + + /** Get the blocks allocated to the given batch. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = synchronized { + timeToAllocatedBlocks.get(batchTime).map { _.streamIdToAllocatedBlocks }.getOrElse(Map.empty) + } + + /** Get the blocks allocated to the given batch and stream. */ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + timeToAllocatedBlocks.get(batchTime).map { + _.getBlocksOfStream(streamId) + }.getOrElse(Seq.empty) + } + } + + /** Check if any blocks are left to be allocated to batches. */ + def hasUnallocatedReceivedBlocks: Boolean = synchronized { + !streamIdToUnallocatedBlockQueues.values.forall(_.isEmpty) + } + + /** + * Get blocks that have been added but not yet allocated to any batch. This method + * is primarily used for testing. + */ + def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized { + getReceivedBlockQueue(streamId).toSeq + } + + /** Clean up block information of old batches. */ + def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized { + assert(cleanupThreshTime.milliseconds < clock.currentTime()) + val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq + logInfo("Deleting batches " + timesToCleanup) + writeToLog(BatchCleanupEvent(timesToCleanup)) + timeToAllocatedBlocks --= timesToCleanup + logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds)) + log + } + + /** Stop the block tracker. */ + def stop() { + logManagerOption.foreach { _.stop() } + } + + /** + * Recover all the tracker actions from the write ahead logs to recover the state (unallocated + * and allocated block info) prior to failure. + */ + private def recoverFromWriteAheadLogs(): Unit = synchronized { + // Insert the recovered block information + def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) { + logTrace(s"Recovery: Inserting added block $receivedBlockInfo") + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + + // Insert the recovered block-to-batch allocations and clear the queue of received blocks + // (when the blocks were originally allocated to the batch, the queue must have been cleared). + def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { + logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + + s"${allocatedBlocks.streamIdToAllocatedBlocks}") + streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + lastAllocatedBatchTime = batchTime + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + } + + // Cleanup the batch allocations + def cleanupBatches(batchTimes: Seq[Time]) { + logTrace(s"Recovery: Cleaning up batches $batchTimes") + timeToAllocatedBlocks --= batchTimes + } + + logManagerOption.foreach { logManager => + logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") + logManager.readFromLog().foreach { byteBuffer => + logTrace("Recovering record " + byteBuffer) + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + case BlockAdditionEvent(receivedBlockInfo) => + insertAddedBlock(receivedBlockInfo) + case BatchAllocationEvent(time, allocatedBlocks) => + insertAllocatedBatch(time, allocatedBlocks) + case BatchCleanupEvent(batchTimes) => + cleanupBatches(batchTimes) + } + } + } + } + + /** Write an update to the tracker to the write ahead log */ + private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + logDebug(s"Writing to log $record") + logManagerOption.foreach { logManager => + logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + } + } + + /** Get the queue of received blocks belonging to a particular stream */ + private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = { + streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue) + } +} + +private[streaming] object ReceivedBlockTracker { + def checkpointDirToLogDir(checkpointDir: String): String = { + new Path(checkpointDir, "receivedBlockMetadata").toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index d696563bcee83..1c3984d968d20 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,15 +17,16 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} + +import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials import akka.actor._ -import org.apache.spark.{SerializableWritable, Logging, SparkEnv, SparkException} + +import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} -import org.apache.spark.util.AkkaUtils /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -48,23 +49,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err * This class manages the execution of the receivers of NetworkInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() * has been called because it needs the final set of input streams at the time of instantiation. + * + * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing. */ private[streaming] -class ReceiverTracker(ssc: StreamingContext) extends Logging { +class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging { - val receiverInputStreams = ssc.graph.getReceiverInputStreams() - val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*) - val receiverExecutor = new ReceiverLauncher() - val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] - val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - val timeout = AkkaUtils.askTimeout(ssc.conf) - val listenerBus = ssc.scheduler.listenerBus + private val receiverInputStreams = ssc.graph.getReceiverInputStreams() + private val receiverInputStreamIds = receiverInputStreams.map { _.id } + private val receiverExecutor = new ReceiverLauncher() + private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] + private val receivedBlockTracker = new ReceivedBlockTracker( + ssc.sparkContext.conf, + ssc.sparkContext.hadoopConfiguration, + receiverInputStreamIds, + ssc.scheduler.clock, + Option(ssc.checkpointDir) + ) + private val listenerBus = ssc.scheduler.listenerBus // actor is created when generator starts. // This not being null means the tracker has been started and not stopped - var actor: ActorRef = null - var currentTime: Time = null + private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ def start() = synchronized { @@ -75,7 +81,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { if (!receiverInputStreams.isEmpty) { actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), "ReceiverTracker") - receiverExecutor.start() + if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } } @@ -84,45 +90,59 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { def stop() = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers - receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop() // Finally, stop the actor ssc.env.actorSystem.stop(actor) actor = null + receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } } - /** Return all the blocks received from a receiver. */ - def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = { - val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true) - logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks") - receivedBlockInfo.toArray + /** Allocate all unallocated blocks to the given batch. */ + def allocateBlocksToBatch(batchTime: Time): Unit = { + if (receiverInputStreams.nonEmpty) { + receivedBlockTracker.allocateBlocksToBatch(batchTime) + } + } + + /** Get the blocks for the given batch and all input streams. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = { + receivedBlockTracker.getBlocksOfBatch(batchTime) } - private def getReceivedBlockInfoQueue(streamId: Int) = { - receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo]) + /** Get the blocks allocated to the given batch and stream. */ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) + } + } + + /** Clean up metadata older than the given threshold time */ + def cleanupOldMetadata(cleanupThreshTime: Time) { + receivedBlockTracker.cleanupOldBatches(cleanupThreshTime) } /** Register a receiver */ - def registerReceiver( + private def registerReceiver( streamId: Int, typ: String, host: String, receiverActor: ActorRef, sender: ActorRef ) { - if (!receiverInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) + if (!receiverInputStreamIds.contains(streamId)) { + throw new SparkException("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host) - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) } /** Deregister a receiver */ - def deregisterReceiver(streamId: Int, message: String, error: String) { + private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) @@ -131,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -141,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Add new blocks for the given stream */ - def addBlocks(receivedBlockInfo: ReceivedBlockInfo) { - getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " + - receivedBlockInfo.blockStoreResult.blockId) + private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { + receivedBlockTracker.addBlock(receivedBlockInfo) } /** Report error sent by a receiver */ - def reportError(streamId: Int, message: String, error: String) { + private def reportError(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(lastErrorMessage = message, lastError = error) @@ -157,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -167,8 +185,8 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Check if any blocks are left to be processed */ - def hasMoreReceivedBlockIds: Boolean = { - !receivedBlockInfo.values.forall(_.isEmpty) + def hasUnallocatedBlocks: Boolean = { + receivedBlockTracker.hasUnallocatedReceivedBlocks } /** Actor to receive messages from the receivers. */ @@ -178,8 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { registerReceiver(streamId, typ, host, receiverActor, sender) sender ! true case AddBlock(receivedBlockInfo) => - addBlocks(receivedBlockInfo) - sender ! true + sender ! addBlock(receivedBlockInfo) case ReportError(streamId, message, error) => reportError(streamId, message, error) case DeregisterReceiver(streamId, message, error) => @@ -194,6 +211,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { @transient val thread = new Thread() { override def run() { try { + SparkEnv.set(env) startReceivers() } catch { case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") @@ -267,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") - ssc.sparkContext.runJob(tempRDD, startReceiver) + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) logInfo("All of the receivers have been terminated") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 6c8bb50145367..dbab685dc3511 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -17,18 +17,19 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.StreamingContext._ - -import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.SparkContext._ +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.language.existentials +import scala.reflect.ClassTag import util.ManualClock -import org.apache.spark.{SparkException, SparkConf} -import org.apache.spark.streaming.dstream.{WindowedDStream, DStream} -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import scala.reflect.ClassTag + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import scala.collection.mutable +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} class BasicOperationsSuite extends TestSuiteBase { test("map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala new file mode 100644 index 0000000000000..fd9c97f551c62 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} +import scala.util.Random + +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader} +import org.apache.spark.streaming.util.WriteAheadLogSuite._ +import org.apache.spark.util.Utils + +class ReceivedBlockTrackerSuite + extends FunSuite with BeforeAndAfter with Matchers with Logging { + + val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") + conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + + val hadoopConf = new Configuration() + val akkaTimeout = 10 seconds + val streamId = 1 + + var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() + var checkpointDirectory: File = null + + before { + checkpointDirectory = Files.createTempDir() + } + + after { + allReceivedBlockTrackers.foreach { _.stop() } + if (checkpointDirectory != null && checkpointDirectory.exists()) { + FileUtils.deleteDirectory(checkpointDirectory) + checkpointDirectory = null + } + } + + test("block addition, and block to batch allocation") { + val receivedBlockTracker = createTracker(enableCheckpoint = false) + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + val blockInfos = generateBlockInfos() + blockInfos.map(receivedBlockTracker.addBlock) + + // Verify added blocks are unallocated blocks + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Allocate the blocks to a batch and verify that all of them have been allocated + receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldBe empty + + // Allocate no blocks to another batch + receivedBlockTracker.allocateBlocksToBatch(2) + receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + + // Verify that batch 2 cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(2) + } + + // Verify that older batches cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(1) + } + } + + test("block addition, block to batch allocation and cleanup with write ahead log") { + val manualClock = new ManualClock + conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1) + + // Set the time increment level to twice the rotation interval so that every increment creates + // a new log file + val timeIncrementMillis = 2000L + def incrementTime() { + manualClock.addToTime(timeIncrementMillis) + } + + // Generate and add blocks to the given tracker + def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = { + val blockInfos = generateBlockInfos() + blockInfos.map(tracker.addBlock) + blockInfos + } + + // Print the data present in the log ahead files in the log directory + def printLogFiles(message: String) { + val fileContents = getWriteAheadLogFiles().map { file => + (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}") + }.mkString("\n") + logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") + } + + // Start tracker and add blocks + val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock) + val blockInfos1 = addBlockInfos(tracker1) + tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Verify whether write ahead log has correct contents + val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent) + getWrittenLogData() shouldEqual expectedWrittenData1 + getWriteAheadLogFiles() should have size 1 + + // Restart tracker and verify recovered list of unallocated blocks + incrementTime() + val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Allocate blocks to batch and verify whether the unallocated blocks got allocated + val batchTime1 = manualClock.currentTime + tracker2.allocateBlocksToBatch(batchTime1) + tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + + // Add more blocks and allocate to another batch + incrementTime() + val batchTime2 = manualClock.currentTime + val blockInfos2 = addBlockInfos(tracker2) + tracker2.allocateBlocksToBatch(batchTime2) + tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + + // Verify whether log has correct contents + val expectedWrittenData2 = expectedWrittenData1 ++ + Seq(createBatchAllocation(batchTime1, blockInfos1)) ++ + blockInfos2.map(BlockAdditionEvent) ++ + Seq(createBatchAllocation(batchTime2, blockInfos2)) + getWrittenLogData() shouldEqual expectedWrittenData2 + + // Restart tracker and verify recovered state + incrementTime() + val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + tracker3.getUnallocatedBlocks(streamId) shouldBe empty + + // Cleanup first batch but not second batch + val oldestLogFile = getWriteAheadLogFiles().head + incrementTime() + tracker3.cleanupOldBatches(batchTime2) + + // Verify that the batch allocations have been cleaned, and the act has been written to log + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual Seq.empty + getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1)) + + // Verify that at least one log file gets deleted + eventually(timeout(10 seconds), interval(10 millisecond)) { + getWriteAheadLogFiles() should not contain oldestLogFile + } + printLogFiles("After cleanup") + + // Restart tracker and verify recovered state, specifically whether info about the first + // batch has been removed, but not the second batch + incrementTime() + val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker4.getUnallocatedBlocks(streamId) shouldBe empty + tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned + tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + } + + /** + * Create tracker object with the optional provided clock. Use fake clock if you + * want to control time by manually incrementing it to test log cleanup. + */ + def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = { + val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None + val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption) + allReceivedBlockTrackers += tracker + tracker + } + + /** Generate blocks infos using random ids */ + def generateBlockInfos(): Seq[ReceivedBlockInfo] = { + List.fill(5)(ReceivedBlockInfo(streamId, 0, + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + } + + /** Get all the data written in the given write ahead log file. */ + def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = { + getWrittenLogData(Seq(logFile)) + } + + /** + * Get all the data written in the given write ahead log files. By default, it will read all + * files in the test log directory. + */ + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + logFiles.flatMap { + file => new WriteAheadLogReader(file, hadoopConf).toSeq + }.map { byteBuffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.toList + } + + /** Get all the write ahead log files in the test directory */ + def getWriteAheadLogFiles(): Seq[String] = { + import ReceivedBlockTracker._ + val logDir = checkpointDirToLogDir(checkpointDirectory.toString) + getLogFilesInDirectory(logDir).map { _.toString } + } + + /** Create batch allocation object from the given info */ + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) + } + + /** Create batch cleanup object from the given info */ + def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = { + BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply)) + } + + implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds) + + implicit def timeToMillis(time: Time): Long = time.milliseconds +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 10160244bcc91..d2b983c4b4d1a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -117,12 +117,12 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll { ) // Create the RDD and verify whether the returned data is correct - val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) assert(rdd.collect() === data.flatten) if (testStoreInBM) { - val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) assert(rdd2.collect() === data.flatten) assert( From 5b3b6f6f5f029164d7749366506e142b104c1d43 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 5 Nov 2014 10:33:13 -0800 Subject: [PATCH 07/68] [SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary * Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types. * Added Scala and Java examples for GradientBoostedTrees * small cleanups and fixes ### Details GradientBoosting bug fixes (“bug” = bad default options) * Force boostingStrategy.weakLearnerParams.algo = Regression * Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance * Only persist data if not yet persisted (since it causes an error if persisted twice) BoostingStrategy * numEstimators: renamed to numIterations * removed subsamplingRate (duplicated by Strategy) * removed categoricalFeaturesInfo since it belongs with the weak learner params (since boosting can be oblivious to feature type) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added assertValid() method * Updated defaultParams() method and eliminated defaultWeakLearnerParams() since that belongs in Strategy Strategy (for DecisionTree) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added setCategoricalFeaturesInfo method taking Java Map. * Cleaned up assertValid * Changed val’s to def’s since parameters can now be changed. CC: manishamde mengxr codedeft Author: Joseph K. Bradley Closes #3094 from jkbradley/gbt-api and squashes the following commits: 7a27e22 [Joseph K. Bradley] scalastyle fix 52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api e9b8410 [Joseph K. Bradley] Summary of changes --- .../mllib/JavaGradientBoostedTrees.java | 126 +++++++++++++ .../examples/mllib/DecisionTreeRunner.scala | 64 +++++-- .../examples/mllib/GradientBoostedTrees.scala | 146 +++++++++++++++ .../spark/mllib/tree/GradientBoosting.scala | 169 ++++++------------ .../tree/configuration/BoostingStrategy.scala | 78 ++++---- .../mllib/tree/configuration/Strategy.scala | 51 ++++-- .../mllib/tree/GradientBoostingSuite.scala | 34 ++-- 7 files changed, 462 insertions(+), 206 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java new file mode 100644 index 0000000000000..1af2067b2b929 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.util.MLUtils; + +/** + * Classification and regression using gradient-boosted decision trees. + */ +public final class JavaGradientBoostedTrees { + + private static void usage() { + System.err.println("Usage: JavaGradientBoostedTrees " + + " "); + System.exit(-1); + } + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + String algo = "Classification"; + if (args.length >= 1) { + datapath = args[0]; + } + if (args.length >= 2) { + algo = args[1]; + } + if (args.length > 2) { + usage(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Set parameters. + // Note: All features are treated as continuous. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); + boostingStrategy.setNumIterations(10); + boostingStrategy.weakLearnerParams().setMaxDepth(5); + + if (algo.equals("Classification")) { + // Compute the number of classes from the data. + Integer numClasses = data.map(new Function() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + + // Train a GradientBoosting model for classification. + final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + } else if (algo.equals("Regression")) { + // Train a GradientBoosting model for classification. + final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainMSE = + predictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + model); + } else { + usage(); + } + + sc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 49751a30491d0..63f02cf7b98b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -154,20 +154,30 @@ object DecisionTreeRunner { } } - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") - val sc = new SparkContext(conf) - - println(s"DecisionTreeRunner with parameters:\n$params") - + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset, number of classes), + * where the number of classes is inferred from data (and set to 0 for Regression) + */ + private[mllib] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: Algo, + fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = { // Load training data and cache it. - val origExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + val origExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache() } // For classification, re-index classes if needed. - val (examples, classIndexMap, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -205,14 +215,14 @@ object DecisionTreeRunner { } // Create training, test sets. - val splits = if (params.testInput != "") { + val splits = if (testInput != "") { // Load testInput. val numFeatures = examples.take(1)(0).features.size - val origTestExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) + val origTestExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } - params.algo match { + algo match { case Classification => { // classCounts: class --> # examples in class val testExamples = { @@ -229,17 +239,31 @@ object DecisionTreeRunner { } } else { // Split input into training, test. - examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + examples.randomSplit(Array(1.0 - fracTest, fracTest)) } val training = splits(0).cache() val test = splits(1).cache() + val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") examples.unpersist(blocking = false) + (training, test, numClasses) + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") + val sc = new SparkContext(conf) + + println(s"DecisionTreeRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat, + params.testInput, params.algo, params.fracTest) + val impurityCalculator = params.impurity match { case Gini => impurity.Gini case Entropy => impurity.Entropy @@ -338,7 +362,9 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + private[mllib] def meanSquaredError( + tree: WeightedEnsembleModel, + data: RDD[LabeledPoint]): Double = { data.map { y => val err = tree.predict(y.features) - y.label err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala new file mode 100644 index 0000000000000..9b6db01448be0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.util.Utils + +/** + * An example runner for Gradient Boosting using decision trees as weak learners. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. + */ +object GradientBoostedTrees { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + numIterations: Int = 10, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GradientBoostedTrees") { + head("GradientBoostedTrees: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numIterations") + .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val sc = new SparkContext(conf) + + println(s"GradientBoostedTrees with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) + + val boostingStrategy = BoostingStrategy.defaultParams(params.algo) + boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.numIterations = params.numIterations + boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + + val randomSeed = Utils.random.nextInt() + if (params.algo == "Classification") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $testAccuracy") + } else if (params.algo == "Regression") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainMSE = DecisionTreeRunner.meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = DecisionTreeRunner.meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") + } + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala index 1a847201ce157..f729344a682e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala @@ -17,30 +17,49 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ - +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy} -import org.apache.spark.Logging -import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: - * A class that implements gradient boosting for regression and binary classification problems. + * A class that implements Stochastic Gradient Boosting + * for regression and binary classification problems. + * + * The implementation is based upon: + * J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes: + * - This currently can be run with several loss functions. However, only SquaredError is + * fully supported. Specifically, the loss function should be used to compute the gradient + * (to re-label training instances on each iteration) and to weight weak hypotheses. + * Currently, gradients are computed correctly for the available loss functions, + * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. + * Running with those losses will likely behave reasonably, but lacks the same guarantees. + * * @param boostingStrategy Parameters for the gradient boosting algorithm */ @Experimental class GradientBoosting ( private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { + boostingStrategy.weakLearnerParams.algo = Regression + boostingStrategy.weakLearnerParams.impurity = impurity.Variance + + // Ensure values for weak learner are the same as what is provided to the boosting algorithm. + boostingStrategy.weakLearnerParams.numClassesForClassification = + boostingStrategy.numClassesForClassification + + boostingStrategy.assertValid() + /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. @@ -51,6 +70,7 @@ class GradientBoosting ( algo match { case Regression => GradientBoosting.boost(input, boostingStrategy) case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoosting.boost(remappedInput, boostingStrategy) case _ => @@ -118,120 +138,32 @@ object GradientBoosting extends Logging { } /** - * Method to train a gradient boosting binary classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) - * @return WeightedEnsembleModel that can be used for prediction + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] */ - def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) + def train( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + train(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] */ def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainClassifier(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] */ def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainRegressor(input.rdd, boostingStrategy) } - /** * Internal method for performing regression using trees as base learners. * @param input training dataset @@ -247,15 +179,17 @@ object GradientBoosting extends Logging { timer.start("init") // Initialize gradient boosting parameters - val numEstimators = boostingStrategy.numEstimators - val baseLearners = new Array[DecisionTreeModel](numEstimators) - val baseLearnerWeights = new Array[Double](numEstimators) + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate val strategy = boostingStrategy.weakLearnerParams // Cache input - input.persist(StorageLevel.MEMORY_AND_DISK) + if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + } timer.stop("init") @@ -264,7 +198,7 @@ object GradientBoosting extends Logging { logDebug("##########") var data = input - // 1. Initialize tree + // Initialize tree timer.start("building tree 0") val firstTreeModel = new DecisionTree(strategy).train(data) baseLearners(0) = firstTreeModel @@ -280,7 +214,7 @@ object GradientBoosting extends Logging { point.features)) var m = 1 - while (m < numEstimators) { + while (m < numIterations) { timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) @@ -289,6 +223,9 @@ object GradientBoosting extends Logging { timer.stop(s"building tree $m") // Create partial model baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), @@ -305,8 +242,6 @@ object GradientBoosting extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - - // 3. Output classifier new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 501d9ff9ea9b7..abbda040bd528 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,6 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.{Gini, Variance} import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** @@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * @param algo Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. + * @param numIterations Number of iterations of boosting. In other words, the number of + * weak hypotheses used in the final model. * @param loss Loss function used for minimization during gradient boosting. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. * @param numClassesForClassification Number of classes for classification. * (Ignored for regression.) + * This setting overrides any setting in [[weakLearnerParams]]. * Default value is 2 (binary classification). - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are * supported. */ @Experimental case class BoostingStrategy( // Required boosting parameters - algo: Algo, - @BeanProperty var numEstimators: Int, + @BeanProperty var algo: Algo, + @BeanProperty var numIterations: Int, @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var subsamplingRate: Double = 1.0, @BeanProperty var numClassesForClassification: Int = 2, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), @BeanProperty var weakLearnerParams: Strategy) extends Serializable { - require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " + - s"$learningRate.") - require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " + - s"$learningRate.") - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo weakLearnerParams.numClassesForClassification = numClassesForClassification - weakLearnerParams.subsamplingRate = subsamplingRate + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification == 2) + case Regression => + // nothing + case _ => + throw new IllegalArgumentException( + s"BoostingStrategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(learningRate > 0 && learningRate <= 1, + "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.") + } } @Experimental @@ -82,28 +93,17 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: Algo): BoostingStrategy = { - val treeStrategy = defaultWeakLearnerParams(algo) + def defaultParams(algo: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy("Regression") + treeStrategy.maxDepth = 3 algo match { - case Classification => - new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy) - case Regression => - new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy) + case "Classification" => + new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + case "Regression" => + new BoostingStrategy(Algo.withName(algo), 100, SquaredError, + weakLearnerParams = treeStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the boosting.") } } - - /** - * Returns default configuration for the weak learner (decision tree) algorithm - * @param algo Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @return Configuration for weak learner - */ - def defaultWeakLearnerParams(algo: Algo): Strategy = { - // Note: Regression tree used even for classification for GBT. - new Strategy(Regression, Variance, 3) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d09295c507d67..b5b1f82177edc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ */ @Experimental class Strategy ( - val algo: Algo, + @BeanProperty var algo: Algo, @BeanProperty var impurity: Impurity, @BeanProperty var maxDepth: Int, @BeanProperty var numClassesForClassification: Int = 2, @@ -85,17 +85,9 @@ class Strategy ( @BeanProperty var checkpointDir: Option[String] = None, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - if (algo == Classification) { - require(numClassesForClassification >= 2) - } - require(minInstancesPerNode >= 1, - s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") - require(maxMemoryInMB <= 10240, - s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") - - val isMulticlassClassification = + def isMulticlassClassification = algo == Classification && numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures + def isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** @@ -112,6 +104,23 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Sets categoricalFeaturesInfo using a Java Map. + */ + def setCategoricalFeaturesInfo( + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { + setCategoricalFeaturesInfo( + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + /** * Check validity of parameters. * Throws exception if invalid. @@ -143,6 +152,26 @@ class Strategy ( s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + s" feature $feature has $arity categories. The number of categories should be >= 2.") } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } +} + +@Experimental +object Strategy { + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo "Classification" or "Regression" + */ + def defaultStrategy(algo: String): Strategy = algo match { + case "Classification" => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClassesForClassification = 2) + case "Regression" => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClassesForClassification = 0) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala index 970fff82215e2..99a02eda60baf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala @@ -22,9 +22,8 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} -import org.apache.spark.mllib.tree.impurity.{Variance, Gini} +import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss} -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.LocalSparkContext @@ -34,9 +33,8 @@ import org.apache.spark.mllib.util.LocalSparkContext class GradientBoostingSuite extends FunSuite with LocalSparkContext { test("Regression with continuous features: SquaredError") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -48,11 +46,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, 1, treeStrategy) val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) @@ -63,9 +61,8 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { } test("Regression with continuous features: Absolute Error") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -77,11 +74,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, numClassesForClassification = 2, treeStrategy) val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) @@ -91,11 +88,9 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { } } - test("Binary classification with continuous features: Log Loss") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -107,11 +102,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss, + learningRate, numClassesForClassification = 2, treeStrategy) val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) @@ -126,7 +121,6 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { object GradientBoostingSuite { // Combinations for estimators, learning rates and subsamplingRate - val testCombinations - = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) + val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) } From 4c42986cc070d9c5c55c7bf8a2a67585967b1082 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Wed, 5 Nov 2014 14:38:43 -0800 Subject: [PATCH 08/68] [SPARK-4242] [Core] Add SASL to external shuffle service Does three things: (1) Adds SASL to ExternalShuffleClient, (2) puts SecurityManager in BlockManager's constructor, and (3) adds unit test. Author: Aaron Davidson Closes #3108 from aarondav/sasl-client and squashes the following commits: 48b622d [Aaron Davidson] Screw it, let's just get LimitedInputStream 3543b70 [Aaron Davidson] Back out of pom change due to unknown test issue? b58518a [Aaron Davidson] ByteStreams.limit() not available :( cbe451a [Aaron Davidson] Address comments 2bf2908 [Aaron Davidson] [SPARK-4242] [Core] Add SASL to external shuffle service --- LICENSE | 21 +++- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 12 +- .../BlockManagerReplicationSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 4 +- network/common/pom.xml | 1 + .../buffer/FileSegmentManagedBuffer.java | 3 +- .../network/util/LimitedInputStream.java | 87 ++++++++++++++ network/shuffle/pom.xml | 1 + .../spark/network/sasl/SparkSaslClient.java | 1 - .../spark/network/sasl/SparkSaslServer.java | 9 +- .../shuffle/ExternalShuffleClient.java | 31 ++++- .../ExternalShuffleIntegrationSuite.java | 4 +- .../shuffle/ExternalShuffleSecuritySuite.java | 113 ++++++++++++++++++ .../streaming/ReceivedBlockHandlerSuite.scala | 2 +- 15 files changed, 272 insertions(+), 23 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java diff --git a/LICENSE b/LICENSE index f1732fb47afc0..3c667bf45059a 100644 --- a/LICENSE +++ b/LICENSE @@ -754,7 +754,7 @@ SUCH DAMAGE. ======================================================================== -For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java): +For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java): ======================================================================== Copyright (C) 2008 The Android Open Source Project @@ -771,6 +771,25 @@ See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For LimitedInputStream + (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java): +======================================================================== +Copyright (C) 2007 The Guava Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + ======================================================================== BSD-style licenses ======================================================================== diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 45e9d7f243e96..e7454beddbfd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -287,7 +287,7 @@ object SparkEnv extends Logging { // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 655d16c65c8b5..a5fb87b9b2c51 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -72,7 +72,8 @@ private[spark] class BlockManager( val conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) + blockTransferService: BlockTransferService, + securityManager: SecurityManager) extends BlockDataManager with Logging { val diskBlockManager = new DiskBlockManager(this, conf) @@ -115,7 +116,8 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf)) + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager, + securityManager.isAuthenticationEnabled()) } else { blockTransferService } @@ -166,9 +168,10 @@ private[spark] class BlockManager( conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) = { + blockTransferService: BlockTransferService, + securityManager: SecurityManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService) + conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) } /** @@ -219,7 +222,6 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - val attemptsRemaining = logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 1461fa69db90d..f63e772bf1e59 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -62,7 +62,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer) + mapOutputTracker, shuffleManager, transfer, securityMgr) store.initialize("app-id") allStores += store store @@ -263,7 +263,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, - 10000, conf, mapOutputTracker, shuffleManager, failableTransfer) + 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0782876c8e3c6..9529502bc8e10 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -74,7 +74,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer) + mapOutputTracker, shuffleManager, transfer, securityMgr) manager.initialize("app-id") manager } @@ -795,7 +795,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, - new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer) + new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/network/common/pom.xml b/network/common/pom.xml index ea887148d98ba..6144548a8f998 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -50,6 +50,7 @@ com.google.guava guava + 11.0.2 provided diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 89ed79bc63903..5fa1527ddff92 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -30,6 +30,7 @@ import io.netty.channel.DefaultFileRegion; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.LimitedInputStream; /** * A {@link ManagedBuffer} backed by a segment in a file. @@ -101,7 +102,7 @@ public InputStream createInputStream() throws IOException { try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return ByteStreams.limit(is, length); + return new LimitedInputStream(is, length); } catch (IOException e) { try { if (is != null) { diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java new file mode 100644 index 0000000000000..63ca43c046525 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.base.Preconditions; + +/** + * Wraps a {@link InputStream}, limiting the number of bytes which can be read. + * + * This code is from Guava's 14.0 source code, because there is no compatible way to + * use this functionality in both a Guava 11 environment and a Guava >14 environment. + */ +public final class LimitedInputStream extends FilterInputStream { + private long left; + private long mark = -1; + + public LimitedInputStream(InputStream in, long limit) { + super(in); + Preconditions.checkNotNull(in); + Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); + left = limit; + } + @Override public int available() throws IOException { + return (int) Math.min(in.available(), left); + } + // it's okay to mark even if mark isn't supported, as reset won't work + @Override public synchronized void mark(int readLimit) { + in.mark(readLimit); + mark = left; + } + @Override public int read() throws IOException { + if (left == 0) { + return -1; + } + int result = in.read(); + if (result != -1) { + --left; + } + return result; + } + @Override public int read(byte[] b, int off, int len) throws IOException { + if (left == 0) { + return -1; + } + len = (int) Math.min(len, left); + int result = in.read(b, off, len); + if (result != -1) { + left -= result; + } + return result; + } + @Override public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + in.reset(); + left = mark; + } + @Override public long skip(long n) throws IOException { + n = Math.min(n, left); + long skipped = in.skip(n); + left -= skipped; + return skipped; + } +} diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index d271704d98a7a..fe5681d463499 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -51,6 +51,7 @@ com.google.guava guava + 11.0.2 provided diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 72ba737b998bc..9abad1f30a259 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -126,7 +126,6 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback logger.trace("SASL client callback: setting realm"); RealmCallback rc = (RealmCallback) callback; rc.setText(rc.getDefaultText()); - logger.info("Realm callback"); } else if (callback instanceof RealmChoiceCallback) { // ignore (?) } else { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 2c0ce40c75e80..e87b17ead1e1a 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -34,7 +34,8 @@ import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; -import com.google.common.io.BaseEncoding; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -159,12 +160,14 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback /* Encode a byte[] identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8).toCharArray(); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index b0b19ba67bddc..3aa95d00f6b20 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,12 +17,18 @@ package org.apache.spark.network.shuffle; +import java.util.List; + +import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import org.apache.spark.network.util.JavaUtils; @@ -37,18 +43,35 @@ public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); - private final TransportClientFactory clientFactory; + private final TransportConf conf; + private final boolean saslEnabled; + private final SecretKeyHolder secretKeyHolder; + private TransportClientFactory clientFactory; private String appId; - public ExternalShuffleClient(TransportConf conf) { - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); - this.clientFactory = context.createClientFactory(); + /** + * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, + * then secretKeyHolder may be null. + */ + public ExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + this.saslEnabled = saslEnabled; } @Override public void init(String appId) { this.appId = appId; + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + List bootstraps = Lists.newArrayList(); + if (saslEnabled) { + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + } + clientFactory = context.createClientFactory(bootstraps); } @Override diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index bc101f53844d5..71e017b9e4e74 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -267,7 +267,7 @@ public void testFetchNoServer() throws Exception { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java new file mode 100644 index 0000000000000..4c18fcdfbcd88 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleSecuritySuite { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportServer server; + + @Before + public void beforeEach() { + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(), + new TestSecretKeyHolder("my-app-id", "secret")); + TransportContext context = new TransportContext(conf, handler); + this.server = context.createServer(); + } + + @After + public void afterEach() { + if (server != null) { + server.close(); + server = null; + } + } + + @Test + public void testValid() { + validate("my-app-id", "secret"); + } + + @Test + public void testBadAppId() { + try { + validate("wrong-app-id", "secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); + } + } + + @Test + public void testBadSecret() { + try { + validate("my-app-id", "bad-secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + /** Creates an ExternalShuffleClient and attempts to register with the server. */ + private void validate(String appId, String secretKey) { + ExternalShuffleClient client = + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "")); + client.close(); + } + + /** Provides a secret key holder which always returns the given secret key, for a single appId. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + private final String appId; + private final String secretKey; + + TestSecretKeyHolder(String appId, String secretKey) { + this.appId = appId; + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + if (!appId.equals(this.appId)) { + throw new IllegalArgumentException("Wrong appId!"); + } + return secretKey; + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 0f27f55fec4f3..9efe15d01ed0c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -73,7 +73,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr)) + new NioBlockTransferService(conf, securityMgr), securityMgr) blockManager.initialize("app-id") tempDirectory = Files.createTempDir() From a46497eecc50f854c5c5701dc2b8a2468b76c085 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 5 Nov 2014 15:30:31 -0800 Subject: [PATCH 09/68] [SPARK-3984] [SPARK-3983] Fix incorrect scheduler delay and display task deserialization time in UI This commit fixes the scheduler delay in the UI (which previously included things that are not scheduler delay, like time to deserialize the task and serialize the result), and also adds information about time to deserialize tasks to the optional additional metrics. Time to deserialize the task can be large relative to task time for short jobs, and understanding when it is high can help developers realize that they should try to reduce closure size (e.g, by including less data in the task description). cc shivaram etrain Author: Kay Ousterhout Closes #2832 from kayousterhout/SPARK-3983 and squashes the following commits: 0c1398e [Kay Ousterhout] Fixed ordering 531575d [Kay Ousterhout] Removed executor launch time 1f13afe [Kay Ousterhout] Minor spacing fixes 335be4b [Kay Ousterhout] Made metrics hideable 5bc3cba [Kay Ousterhout] [SPARK-3984] [SPARK-3983] Improve UI task metrics. --- .../org/apache/spark/executor/Executor.scala | 4 +-- .../scala/org/apache/spark/ui/ToolTips.scala | 3 ++ .../org/apache/spark/ui/jobs/StagePage.scala | 31 ++++++++++++++++++- .../spark/ui/jobs/TaskDetailsClassNames.scala | 1 + 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index abc1dd0be6237..96114571d6c77 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -161,7 +161,7 @@ private[spark] class Executor( } override def run() { - val startTime = System.currentTimeMillis() + val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") @@ -206,7 +206,7 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.executorDeserializeTime = taskStart - startTime + m.executorDeserializeTime = taskStart - deserializeStartTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime m.resultSerializationTime = afterSerialization - beforeSerialization diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index f02904df31fcf..51dc08f668a43 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,6 +24,9 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" + val TASK_DESERIALIZATION_TIME = + """Time spent deserializating the task closure on the executor.""" + val INPUT = "Bytes read from Hadoop or from Spark storage." val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7cc03b7d333df..63ed5fc4949c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,6 +112,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Scheduler Delay +
  • + + + Task Deserialization Time + +
  • @@ -147,6 +154,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), ("GC Time", TaskDetailsClassNames.GC_TIME), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ @@ -179,6 +187,17 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } } + val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.executorDeserializeTime.toDouble + } + val deserializationQuantiles = + + + Task Deserialization Time + + +: getFormattedTimeQuantiles(deserializationTimes) + val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } @@ -266,6 +285,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val listings: Seq[Seq[Node]] = Seq( {serviceQuantiles}, {schedulerDelayQuantiles}, + + {deserializationQuantiles} + {gcQuantiles}, {serializationQuantiles} @@ -314,6 +336,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = info.gettingResultTime @@ -367,6 +390,10 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { class={TaskDetailsClassNames.SCHEDULER_DELAY}> {UIUtils.formatDuration(schedulerDelay.toLong)} + + {UIUtils.formatDuration(taskDeserializationTime.toLong)} + {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} @@ -424,6 +451,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { (info.finishTime - info.launchTime) } } - totalExecutionTime - metrics.executorRunTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + totalExecutionTime - metrics.executorRunTime - executorOverhead } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 23d672cabda07..eb371bd0ea7ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -24,6 +24,7 @@ package org.apache.spark.ui.jobs private object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val GC_TIME = "gc_time" + val TASK_DESERIALIZATION_TIME = "deserialization_time" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } From f37817b18a479839b2e6118cc1cbd1059a94db52 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Wed, 5 Nov 2014 15:38:48 -0800 Subject: [PATCH 10/68] SPARK-4222 [CORE] use readFully in FixedLengthBinaryRecordReader replaces the existing read() call with readFully(). Author: industrial-sloth Closes #3093 from industrial-sloth/branch-1.2-fixedLenRecRdr and squashes the following commits: a245c8a [industrial-sloth] use readFully in FixedLengthBinaryRecordReader (cherry picked from commit 6844e7a8219ac78790a422ffd5054924e7d2bea1) Signed-off-by: Matei Zaharia --- .../org/apache/spark/input/FixedLengthBinaryRecordReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 5164a74bec4e9..36a1e5d475f46 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -115,7 +115,7 @@ private[spark] class FixedLengthBinaryRecordReader if (currentPosition < splitEnd) { // setup a buffer to store the record val buffer = recordValue.getBytes - fileInputStream.read(buffer, 0, recordLength) + fileInputStream.readFully(buffer) // update our current position currentPosition = currentPosition + recordLength // return true From 61a5cced049a8056292ba94f23fa7bd040f50685 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Nov 2014 15:42:05 -0800 Subject: [PATCH 11/68] [SPARK-3797] Run external shuffle service in Yarn NM This creates a new module `network/yarn` that depends on `network/shuffle` recently created in #3001. This PR introduces a custom Yarn auxiliary service that runs the external shuffle service. As of the changes here this shuffle service is required for using dynamic allocation with Spark. This is still WIP mainly because it doesn't handle security yet. I have tested this on a stable Yarn cluster. Author: Andrew Or Closes #3082 from andrewor14/yarn-shuffle-service and squashes the following commits: ef3ddae [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 0ee67a2 [Andrew Or] Minor wording suggestions 1c66046 [Andrew Or] Remove unused provided dependencies 0eb6233 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 6489db5 [Andrew Or] Try catch at the right places 7b71d8f [Andrew Or] Add detailed java docs + reword a few comments d1124e4 [Andrew Or] Add security to shuffle service (INCOMPLETE) 5f8a96f [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 9b6e058 [Andrew Or] Address various feedback f48b20c [Andrew Or] Fix tests again f39daa6 [Andrew Or] Do not make network-yarn an assembly module 761f58a [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 15a5b37 [Andrew Or] Fix build for Hadoop 1.x baff916 [Andrew Or] Fix tests 5bf9b7e [Andrew Or] Address a few minor comments 5b419b8 [Andrew Or] Add missing license header 804e7ff [Andrew Or] Include the Yarn shuffle service jar in the distribution cd076a4 [Andrew Or] Require external shuffle service for dynamic allocation ea764e0 [Andrew Or] Connect to Yarn shuffle service only if it's enabled 1bf5109 [Andrew Or] Use the shuffle service port specified through hadoop config b4b1f0c [Andrew Or] 4 tabs -> 2 tabs 43dcb96 [Andrew Or] First cut integration of shuffle service with Yarn aux service b54a0c4 [Andrew Or] Initial skeleton for Yarn shuffle service --- .../spark/ExecutorAllocationManager.scala | 37 +++- .../apache/spark/storage/BlockManager.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 16 ++ make-distribution.sh | 3 + .../network/sasl/ShuffleSecretManager.java | 117 ++++++++++++ network/yarn/pom.xml | 58 ++++++ .../network/yarn/YarnShuffleService.java | 176 ++++++++++++++++++ .../yarn/util/HadoopConfigProvider.java | 42 +++++ pom.xml | 2 + project/SparkBuild.scala | 8 +- .../spark/deploy/yarn/ExecutorRunnable.scala | 16 ++ .../spark/deploy/yarn/ExecutorRunnable.scala | 16 ++ 12 files changed, 483 insertions(+), 16 deletions(-) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java create mode 100644 network/yarn/pom.xml create mode 100644 network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java create mode 100644 network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index c11f1db0064fd..ef93009a074e7 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -66,7 +66,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Lower and upper bounds on the number of executors. These are required. private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) - verifyBounds() // How long there must be backlogged tasks for before an addition is triggered private val schedulerBacklogTimeout = conf.getLong( @@ -77,9 +76,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) // How long an executor must be idle for before it is removed - private val removeThresholdSeconds = conf.getLong( + private val executorIdleTimeout = conf.getLong( "spark.dynamicAllocation.executorIdleTimeout", 600) + // During testing, the methods to actually kill and add executors are mocked out + private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + + validateSettings() + // Number of executors to add in the next round private var numExecutorsToAdd = 1 @@ -103,17 +107,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Polling loop interval (ms) private val intervalMillis: Long = 100 - // Whether we are testing this class. This should only be used internally. - private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) - // Clock used to schedule when executors should be added and removed private var clock: Clock = new RealClock /** - * Verify that the lower and upper bounds on the number of executors are valid. + * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. */ - private def verifyBounds(): Unit = { + private def validateSettings(): Unit = { if (minNumExecutors < 0 || maxNumExecutors < 0) { throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") } @@ -124,6 +125,22 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } + if (schedulerBacklogTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") + } + if (sustainedSchedulerBacklogTimeout <= 0) { + throw new SparkException( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") + } + if (executorIdleTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + } + // Require external shuffle service for dynamic allocation + // Otherwise, we may lose shuffle files when killing executors + if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) { + throw new SparkException("Dynamic allocation of executors requires the external " + + "shuffle service. You may enable this through spark.shuffle.service.enabled.") + } } /** @@ -254,7 +271,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging val removeRequestAcknowledged = testing || sc.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -329,8 +346,8 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging private def onExecutorIdle(executorId: String): Unit = synchronized { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)") - removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000 + s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a5fb87b9b2c51..e48d7772d6ee9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -40,7 +40,6 @@ import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -97,7 +96,12 @@ private[spark] class BlockManager( private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) - private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337) + + // Port used by the external shuffle service. In Yarn mode, this may be already be + // set through the Hadoop configuration as the server is launched in the Yarn NM. + private val externalShuffleServicePort = + Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + // Check that we're not using external shuffle service with consolidated shuffle files. if (externalShuffleServiceEnabled && conf.getBoolean("spark.shuffle.consolidateFiles", false) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6ab94af9f3739..7caf6bcf94ef3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -45,6 +45,7 @@ import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ @@ -1780,6 +1781,21 @@ private[spark] object Utils extends Logging { val manifest = new JarManifest(manifestUrl.openStream()) manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION) }.getOrElse("Unknown") + + /** + * Return the value of a config either through the SparkConf or the Hadoop configuration + * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf + * if the key is not set in the Hadoop configuration. + */ + def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { + val sparkValue = conf.get(key, default) + if (SparkHadoopUtil.get.isYarnMode) { + SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue) + } else { + sparkValue + } + } + } /** diff --git a/make-distribution.sh b/make-distribution.sh index 0bc839e1dbe4d..fac7f7e284be4 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -181,6 +181,9 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-yarn*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-shuffle*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-common*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java new file mode 100644 index 0000000000000..e66c4af0f1ebd --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.lang.Override; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.sasl.SecretKeyHolder; + +/** + * A class that manages shuffle secret used by the external shuffle service. + */ +public class ShuffleSecretManager implements SecretKeyHolder { + private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); + private final ConcurrentHashMap shuffleSecretMap; + + private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + + // Spark user used for authenticating SASL connections + // Note that this must match the value in org.apache.spark.SecurityManager + private static final String SPARK_SASL_USER = "sparkSaslUser"; + + /** + * Convert the given string to a byte buffer. The resulting buffer can be converted back to + * the same string through {@link #bytesToString(ByteBuffer)}. This is used if the external + * shuffle service represents shuffle secrets as bytes buffers instead of strings. + */ + public static ByteBuffer stringToBytes(String s) { + return ByteBuffer.wrap(s.getBytes(UTF8_CHARSET)); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be converted back to + * the same byte buffer through {@link #stringToBytes(String)}. This is used if the external + * shuffle service represents shuffle secrets as bytes buffers instead of strings. + */ + public static String bytesToString(ByteBuffer b) { + return new String(b.array(), UTF8_CHARSET); + } + + public ShuffleSecretManager() { + shuffleSecretMap = new ConcurrentHashMap(); + } + + /** + * Register an application with its secret. + * Executors need to first authenticate themselves with the same secret before + * fetching shuffle files written by other executors in this application. + */ + public void registerApp(String appId, String shuffleSecret) { + if (!shuffleSecretMap.contains(appId)) { + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); + } else { + logger.debug("Application {} already registered", appId); + } + } + + /** + * Register an application with its secret specified as a byte buffer. + */ + public void registerApp(String appId, ByteBuffer shuffleSecret) { + registerApp(appId, bytesToString(shuffleSecret)); + } + + /** + * Unregister an application along with its secret. + * This is called when the application terminates. + */ + public void unregisterApp(String appId) { + if (shuffleSecretMap.contains(appId)) { + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); + } else { + logger.warn("Attempted to unregister application {} when it is not registered", appId); + } + } + + /** + * Return the Spark user for authenticating SASL connections. + */ + @Override + public String getSaslUser(String appId) { + return SPARK_SASL_USER; + } + + /** + * Return the secret key registered with the given application. + * This key is used to authenticate the executors before they can fetch shuffle files + * written by this application from the external shuffle service. If the specified + * application is not registered, return null. + */ + @Override + public String getSecretKey(String appId) { + return shuffleSecretMap.get(appId); + } +} diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml new file mode 100644 index 0000000000000..e60d8c1f7876c --- /dev/null +++ b/network/yarn/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-yarn_2.10 + jar + Spark Project Yarn Shuffle Service Code + http://spark.apache.org/ + + network-yarn + + + + + + org.apache.spark + spark-network-shuffle_2.10 + ${project.version} + + + + + org.apache.hadoop + hadoop-client + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java new file mode 100644 index 0000000000000..bb0b8f7e6cba6 --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn; + +import java.lang.Override; +import java.nio.ByteBuffer; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.server.api.AuxiliaryService; +import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; +import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; +import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; +import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.ShuffleSecretManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.yarn.util.HadoopConfigProvider; + +/** + * An external shuffle service used by Spark on Yarn. + * + * This is intended to be a long-running auxiliary service that runs in the NodeManager process. + * A Spark application may connect to this service by setting `spark.shuffle.service.enabled`. + * The application also automatically derives the service port through `spark.shuffle.service.port` + * specified in the Yarn configuration. This is so that both the clients and the server agree on + * the same port to communicate on. + * + * The service also optionally supports authentication. This ensures that executors from one + * application cannot read the shuffle files written by those from another. This feature can be + * enabled by setting `spark.authenticate` in the Yarn configuration before starting the NM. + * Note that the Spark application must also set `spark.authenticate` manually and, unlike in + * the case of the service port, will not inherit this setting from the Yarn configuration. This + * is because an application running on the same Yarn cluster may choose to not use the external + * shuffle service, in which case its setting of `spark.authenticate` should be independent of + * the service's. + */ +public class YarnShuffleService extends AuxiliaryService { + private final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class); + + // Port on which the shuffle server listens for fetch requests + private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"; + private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337; + + // Whether the shuffle server should authenticate fetch requests + private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; + private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; + + // An entity that manages the shuffle secret per application + // This is used only if authentication is enabled + private ShuffleSecretManager secretManager; + + // The actual server that serves shuffle files + private TransportServer shuffleServer = null; + + public YarnShuffleService() { + super("spark_shuffle"); + logger.info("Initializing YARN shuffle service for Spark"); + } + + /** + * Return whether authentication is enabled as specified by the configuration. + * If so, fetch requests will fail unless the appropriate authentication secret + * for the application is provided. + */ + private boolean isAuthenticationEnabled() { + return secretManager != null; + } + + /** + * Start the shuffle server with the given configuration. + */ + @Override + protected void serviceInit(Configuration conf) { + // If authentication is enabled, set up the shuffle server to use a + // special RPC handler that filters out unauthenticated fetch requests + boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); + RpcHandler rpcHandler = new ExternalShuffleBlockHandler(); + if (authEnabled) { + secretManager = new ShuffleSecretManager(); + rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + } + + int port = conf.getInt( + SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportContext transportContext = new TransportContext(transportConf, rpcHandler); + shuffleServer = transportContext.createServer(port); + String authEnabledString = authEnabled ? "enabled" : "not enabled"; + logger.info("Started YARN shuffle service for Spark on port {}. " + + "Authentication is {}.", port, authEnabledString); + } + + @Override + public void initializeApplication(ApplicationInitializationContext context) { + String appId = context.getApplicationId().toString(); + try { + ByteBuffer shuffleSecret = context.getApplicationDataForService(); + logger.info("Initializing application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.registerApp(appId, shuffleSecret); + } + } catch (Exception e) { + logger.error("Exception when initializing application {}", appId, e); + } + } + + @Override + public void stopApplication(ApplicationTerminationContext context) { + String appId = context.getApplicationId().toString(); + try { + logger.info("Stopping application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.unregisterApp(appId); + } + } catch (Exception e) { + logger.error("Exception when stopping application {}", appId, e); + } + } + + @Override + public void initializeContainer(ContainerInitializationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Initializing container {}", containerId); + } + + @Override + public void stopContainer(ContainerTerminationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Stopping container {}", containerId); + } + + /** + * Close the shuffle server to clean up any associated state. + */ + @Override + protected void serviceStop() { + try { + if (shuffleServer != null) { + shuffleServer.close(); + } + } catch (Exception e) { + logger.error("Exception when stopping service", e); + } + } + + // Not currently used + @Override + public ByteBuffer getMetaData() { + return ByteBuffer.allocate(0); + } + +} diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java new file mode 100644 index 0000000000000..884861752e80d --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn.util; + +import java.util.NoSuchElementException; + +import org.apache.hadoop.conf.Configuration; + +import org.apache.spark.network.util.ConfigProvider; + +/** Use the Hadoop configuration to obtain config values. */ +public class HadoopConfigProvider extends ConfigProvider { + private final Configuration conf; + + public HadoopConfigProvider(Configuration conf) { + this.conf = conf; + } + + @Override + public String get(String name) { + String value = conf.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/pom.xml b/pom.xml index eb613531b8a5f..88ef67c515b3a 100644 --- a/pom.xml +++ b/pom.xml @@ -1229,6 +1229,7 @@ yarn-alpha yarn + network/yarn @@ -1236,6 +1237,7 @@ yarn yarn + network/yarn diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 33618f5401768..657e4b4432775 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -38,9 +38,9 @@ object BuildCommons { "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") - .map(ProjectRef(buildLocation, _)) + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, networkYarn, java8Tests, + sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "network-yarn", + "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") .map(ProjectRef(buildLocation, _)) @@ -143,7 +143,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach { + streamingFlumeSink, networkCommon, networkShuffle, networkYarn).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 7ee4b5c842df1..5f47c79cabaee 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.sasl.ShuffleSecretManager @deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( @@ -90,6 +91,21 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + ShuffleSecretManager.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager val startReq = Records.newRecord(classOf[StartContainerRequest]) .asInstanceOf[StartContainerRequest] diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 0b5a92d87d722..18f48b4b6caf6 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.sasl.ShuffleSecretManager class ExecutorRunnable( @@ -89,6 +90,21 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + ShuffleSecretManager.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager nmClient.startContainer(container, ctx) } From 868cd4c3ca11e6ecc4425b972d9a20c360b52425 Mon Sep 17 00:00:00 2001 From: "jay@apache.org" Date: Wed, 5 Nov 2014 15:45:34 -0800 Subject: [PATCH 12/68] SPARK-4040. Update documentation to exemplify use of local (n) value, fo... This is a minor docs update which helps to clarify the way local[n] is used for streaming apps. Author: jay@apache.org Closes #2964 from jayunit100/SPARK-4040 and squashes the following commits: 35b5a5e [jay@apache.org] SPARK-4040: Update documentation to exemplify use of local (n) value. --- docs/configuration.md | 10 ++++++++-- docs/streaming-programming-guide.md | 14 +++++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 685101ea5c9c9..0f9eb81f6e993 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -21,16 +21,22 @@ application. These properties can be set directly on a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your `SparkContext`. `SparkConf` allows you to configure some of the common properties (e.g. master URL and application name), as well as arbitrary key-value pairs through the -`set()` method. For example, we could initialize an application as follows: +`set()` method. For example, we could initialize an application with two threads as follows: + +Note that we run with local[2], meaning two threads - which represents "minimal" parallelism, +which can help detect bugs that only exist when we run in a distributed context. {% highlight scala %} val conf = new SparkConf() - .setMaster("local") + .setMaster("local[2]") .setAppName("CountingSheep") .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} +Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually +require one to prevent any sort of starvation issues. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 8bbba88b31978..44a1f3ad7560b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -68,7 +68,9 @@ import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -// Create a local StreamingContext with two working thread and batch interval of 1 second +// Create a local StreamingContext with two working thread and batch interval of 1 second. +// The master requires 2 cores to prevent from a starvation scenario. + val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} @@ -586,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver]( A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are: -##### Points to remember: +##### Points to remember {:.no_toc} -- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. -- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data. - +- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. +- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay). So, a "local" master URL in a streaming app is generally going to cause starvation for the processor. +Thus in any streaming app, you generally will want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally. +See [Spark Properties] (configuration.html#spark-properties.html). + ### Basic Sources {:.no_toc} From f7ac8c2b1de96151231617846b7468d23379c74a Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Wed, 5 Nov 2014 15:49:42 -0800 Subject: [PATCH 13/68] SPARK-3223 runAsSparkUser cannot change HDFS write permission properly i... ...n mesos cluster mode - change master newer Author: Jongyoul Lee Closes #3034 from jongyoul/SPARK-3223 and squashes the following commits: 42b2ed3 [Jongyoul Lee] SPARK-3223 runAsSparkUser cannot change HDFS write permission properly in mesos cluster mode - change master newer --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d8c0e2f66df01..e4b859846035c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { { val ret = driver.run() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8e2faff90f9b2..7d097a3a7aaa3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { val ret = driver.run() From cb0eae3b78d7f6f56c0b9521ee48564a4967d3de Mon Sep 17 00:00:00 2001 From: Brenden Matthews Date: Wed, 5 Nov 2014 16:02:44 -0800 Subject: [PATCH 14/68] [SPARK-4158] Fix for missing resources. Mesos offers may not contain all resources, and Spark needs to check to ensure they are present and sufficient. Spark may throw an erroneous exception when resources aren't present. Author: Brenden Matthews Closes #3024 from brndnmtthws/fix-mesos-resource-misuse and squashes the following commits: e5f9580 [Brenden Matthews] [SPARK-4158] Fix for missing resources. --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 3 +-- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e4b859846035c..5289661eb896b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -242,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Build a Mesos resource protobuf object */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 7d097a3a7aaa3..c5f3493477bc5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -278,8 +278,7 @@ private[spark] class MesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Turn a Spark TaskDescription into a Mesos task */ From c315d1316cb2372e90ae3a12f72d5b3304435a6b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 5 Nov 2014 19:51:18 -0800 Subject: [PATCH 15/68] [SPARK-4254] [mllib] MovieLensALS bug fix Changed code so it does not try to serialize Params. CC: mengxr debasish83 srowen Author: Joseph K. Bradley Closes #3116 from jkbradley/als-bugfix and squashes the following commits: e575bd8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into als-bugfix 9401b16 [Joseph K. Bradley] changed implicitPrefs so it is not serialized to fix MovieLensALS example bug --- .../scala/org/apache/spark/examples/mllib/MovieLensALS.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 8796c28db8a66..91a0a860d6c71 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -106,9 +106,11 @@ object MovieLensALS { Logger.getRootLogger.setLevel(Level.WARN) + val implicitPrefs = params.implicitPrefs + val ratings = sc.textFile(params.input).map { line => val fields = line.split("::") - if (params.implicitPrefs) { + if (implicitPrefs) { /* * MovieLens ratings are on a scale of 1-5: * 5: Must see From 3d2b5bc5bb979d8b0b71e06bc0f4548376fdbb98 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 5 Nov 2014 19:56:16 -0800 Subject: [PATCH 16/68] [SPARK-4262][SQL] add .schemaRDD to JavaSchemaRDD marmbrus Author: Xiangrui Meng Closes #3125 from mengxr/SPARK-4262 and squashes the following commits: 307695e [Xiangrui Meng] add .schemaRDD to JavaSchemaRDD --- .../scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 1e0ccb368a276..78e8d908fe0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -47,6 +47,9 @@ class JavaSchemaRDD( private[sql] val baseSchemaRDD = new SchemaRDD(sqlContext, logicalPlan) + /** Returns the underlying Scala SchemaRDD. */ + val schemaRDD: SchemaRDD = baseSchemaRDD + override val classTag = scala.reflect.classTag[Row] override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd) From db45f5ad0368760dbeaa618a04f66ae9b2bed656 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 5 Nov 2014 20:45:35 -0800 Subject: [PATCH 17/68] [SPARK-4137] [EC2] Don't change working dir on user This issue was uncovered after [this discussion](https://issues.apache.org/jira/browse/SPARK-3398?focusedCommentId=14187471&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-14187471). Don't change the working directory on the user. This breaks relative paths the user may pass in, e.g., for the SSH identity file. ``` ./ec2/spark-ec2 -i ../my.pem ``` This patch will preserve the user's current working directory and allow calls like the one above to work. Author: Nicholas Chammas Closes #2988 from nchammas/spark-ec2-cwd and squashes the following commits: f3850b5 [Nicholas Chammas] pep8 fix fbc20c7 [Nicholas Chammas] revert to old commenting style 752f958 [Nicholas Chammas] specify deploy.generic path absolutely bcdf6a5 [Nicholas Chammas] fix typo 77871a2 [Nicholas Chammas] add clarifying comment ce071fc [Nicholas Chammas] don't change working dir --- ec2/spark-ec2 | 8 ++++++-- ec2/spark_ec2.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 31f9771223e51..4aa908242eeaa 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -18,5 +18,9 @@ # limitations under the License. # -cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" +# Preserve the user's CWD so that relative paths are passed correctly to +#+ the underlying Python script. +SPARK_EC2_DIR="$(dirname $0)" + +PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \ + python "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 50f88f735650e..a5396c2375915 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -40,6 +40,7 @@ from boto import ec2 DEFAULT_SPARK_VERSION = "1.1.0" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information @@ -593,7 +594,14 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ) print "Deploying files to master..." - deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) + deploy_files( + conn=conn, + root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", + opts=opts, + master_nodes=master_nodes, + slave_nodes=slave_nodes, + modules=modules + ) print "Running setup on master..." setup_spark_cluster(master, opts) @@ -730,6 +738,8 @@ def get_num_disks(instance_type): # cluster (e.g. lists of masters and slaves). Files are only deployed to # the first master instance in the cluster, and we expect the setup # script to be run on that instance to copy them to other nodes. +# +# root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): active_master = master_nodes[0].public_dns_name From 5f27ae16d5b016fae4afeb0f2ad779fd3130b390 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 6 Nov 2014 00:03:03 -0800 Subject: [PATCH 18/68] [SPARK-4255] Fix incorrect table striping This commit stripes table rows after hiding some rows, to ensure that rows are correct striped to alternate white and grey even when rows are hidden by default. Author: Kay Ousterhout Closes #3117 from kayousterhout/striping and squashes the following commits: be6e10a [Kay Ousterhout] [SPARK-4255] Fix incorrect table striping --- .../org/apache/spark/ui/static/additional-metrics.js | 2 ++ core/src/main/resources/org/apache/spark/ui/static/table.js | 5 ----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index c5936b5038ac9..badd85ed48c82 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -39,6 +39,8 @@ $(function() { var column = "table ." + $(this).attr("name"); $(column).hide(); }); + // Stripe table rows after rows have been hidden to ensure correct striping. + stripeTables(); $("input:checkbox").click(function() { var column = "table ." + $(this).attr("name"); diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 32187ba6e8df0..6bb03015abb51 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -28,8 +28,3 @@ function stripeTables() { }); }); } - -/* Stripe all tables after pages finish loading. */ -$(function() { - stripeTables(); -}); From b41a39e24038876359aeb7ce2bbbb4de2234e5f3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Nov 2014 00:22:19 -0800 Subject: [PATCH 19/68] [SPARK-4186] add binaryFiles and binaryRecords in Python add binaryFiles() and binaryRecords() in Python ``` binaryFiles(self, path, minPartitions=None): :: Developer API :: Read a directory of binary files from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI as a byte array. Each file is read as a single record and returned in a key-value pair, where the key is the path of each file, the value is the content of each file. Note: Small files are preferred, large file is also allowable, but may cause bad performance. binaryRecords(self, path, recordLength): Load data from a flat binary file, assuming each record is a set of numbers with the specified numerical format (see ByteBuffer), and the number of bytes per record is constant. :param path: Directory to the input data files :param recordLength: The length at which to split the records ``` Author: Davies Liu Closes #3078 from davies/binary and squashes the following commits: cd0bdbd [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 3aa349b [Davies Liu] add experimental notes 24e84b6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 5ceaa8a [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 1900085 [Davies Liu] bugfix bb22442 [Davies Liu] add binaryFiles and binaryRecords in Python --- .../scala/org/apache/spark/SparkContext.scala | 4 ++ .../spark/api/java/JavaSparkContext.scala | 12 ++--- .../apache/spark/api/python/PythonRDD.scala | 45 ++++++++++++------- python/pyspark/context.py | 32 ++++++++++++- python/pyspark/tests.py | 19 ++++++++ 5 files changed, 90 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3cdaa6a9cc8a8..03ea672c813d1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -560,6 +560,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** + * :: Experimental :: + * * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file * (useful for binary data) * @@ -602,6 +604,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { } /** + * :: Experimental :: + * * Load data from a flat binary file, assuming the length of each record is constant. * * @param path Directory to the input data files diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e3aeba7e6c39d..5c6e8d32c5c8a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,11 +21,6 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import java.io.DataInputStream - -import org.apache.hadoop.io.{BytesWritable, LongWritable} -import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat} - import scala.collection.JavaConversions import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -33,6 +28,7 @@ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration +import org.apache.spark.input.PortableDataStream import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -286,6 +282,8 @@ class JavaSparkContext(val sc: SparkContext) new JavaPairRDD(sc.binaryFiles(path, minPartitions)) /** + * :: Experimental :: + * * Read a directory of binary files from HDFS, a local file system (available on all nodes), * or any Hadoop-supported file system URI as a byte array. Each file is read as a single * record and returned in a key-value pair, where the key is the path of each file, @@ -312,15 +310,19 @@ class JavaSparkContext(val sc: SparkContext) * * @note Small files are preferred; very large files but may cause bad performance. */ + @Experimental def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) /** + * :: Experimental :: + * * Load data from a flat binary file, assuming the length of each record is constant. * * @param path Directory to the input data files * @return An RDD of data with values, represented as byte arrays */ + @Experimental def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { new JavaRDD(sc.binaryRecords(path, recordLength)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index e94ccdcd47bb7..45beb8fc8c925 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,6 +21,8 @@ import java.io._ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import org.apache.spark.input.PortableDataStream + import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials @@ -395,22 +397,33 @@ private[spark] object PythonRDD extends Logging { newIter.asInstanceOf[Iterator[String]].foreach { str => writeUTF(str, dataOut) } - case pair: Tuple2[_, _] => - pair._1 match { - case bytePair: Array[Byte] => - newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair => - dataOut.writeInt(pair._1.length) - dataOut.write(pair._1) - dataOut.writeInt(pair._2.length) - dataOut.write(pair._2) - } - case stringPair: String => - newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair => - writeUTF(pair._1, dataOut) - writeUTF(pair._2, dataOut) - } - case other => - throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass) + case stream: PortableDataStream => + newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, stream: PortableDataStream) => + newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { + case (key, stream) => + writeUTF(key, dataOut) + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, value: String) => + newIter.asInstanceOf[Iterator[(String, String)]].foreach { + case (key, value) => + writeUTF(key, dataOut) + writeUTF(value, dataOut) + } + case (key: Array[Byte], value: Array[Byte]) => + newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { + case (key, value) => + dataOut.writeInt(key.length) + dataOut.write(key) + dataOut.writeInt(value.length) + dataOut.write(value) } case other => throw new SparkException("Unexpected element type " + first.getClass) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a0e4821728c8b..faa5952258aef 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer, AutoBatchedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -388,6 +388,36 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode))) + def binaryFiles(self, path, minPartitions=None): + """ + :: Experimental :: + + Read a directory of binary files from HDFS, a local file system + (available on all nodes), or any Hadoop-supported file system URI + as a byte array. Each file is read as a single record and returned + in a key-value pair, where the key is the path of each file, the + value is the content of each file. + + Note: Small files are preferred, large file is also allowable, but + may cause bad performance. + """ + minPartitions = minPartitions or self.defaultMinPartitions + return RDD(self._jsc.binaryFiles(path, minPartitions), self, + PairDeserializer(UTF8Deserializer(), NoOpSerializer())) + + def binaryRecords(self, path, recordLength): + """ + :: Experimental :: + + Load data from a flat binary file, assuming each record is a set of numbers + with the specified numerical format (see ByteBuffer), and the number of + bytes per record is constant. + + :param path: Directory to the input data files + :param recordLength: The length at which to split the records + """ + return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer()) + def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7e61b017efa75..9f625c5c6ca48 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1110,6 +1110,25 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = "short binary data" + with open(os.path.join(path, "part-0000"), 'w') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(range(100), result) + class OutputFormatTests(ReusedPySparkTestCase): From 23eaf0e12ff221dcca40a79e61b6cc5e7c846cb5 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 10:45:46 -0800 Subject: [PATCH 20/68] [SPARK-4264] Completion iterator should only invoke callback once Author: Aaron Davidson Closes #3128 from aarondav/compiter and squashes the following commits: 698e4be [Aaron Davidson] [SPARK-4264] Completion iterator should only invoke callback once --- .../spark/util/CompletionIterator.scala | 5 +- .../spark/util/CompletionIteratorSuite.scala | 47 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index b6a099825f01b..390310243ee0a 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -25,10 +25,13 @@ private[spark] // scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { // scalastyle:on + + private[this] var completed = false def next() = sub.next() def hasNext = { val r = sub.hasNext - if (!r) { + if (!r && !completed) { + completed = true completion() } r diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala new file mode 100644 index 0000000000000..3755d43e25ea8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.scalatest.FunSuite + +class CompletionIteratorSuite extends FunSuite { + test("basic test") { + var numTimesCompleted = 0 + val iter = List(1, 2, 3).iterator + val completionIter = CompletionIterator[Int, Iterator[Int]](iter, { numTimesCompleted += 1 }) + + assert(completionIter.hasNext) + assert(completionIter.next() === 1) + assert(numTimesCompleted === 0) + + assert(completionIter.hasNext) + assert(completionIter.next() === 2) + assert(numTimesCompleted === 0) + + assert(completionIter.hasNext) + assert(completionIter.next() === 3) + assert(numTimesCompleted === 0) + + assert(!completionIter.hasNext) + assert(numTimesCompleted === 1) + + // SPARK-4264: Calling hasNext should not trigger the completion callback again. + assert(!completionIter.hasNext) + assert(numTimesCompleted === 1) + } +} From d15c6e9dc2860bbe56e31ddf71218ccc6d5c841d Mon Sep 17 00:00:00 2001 From: lianhuiwang Date: Thu, 6 Nov 2014 10:46:45 -0800 Subject: [PATCH 21/68] [SPARK-4249][GraphX]fix a problem of EdgePartitionBuilder in Graphx at first srcIds is not initialized and are all 0. so we use edgeArray(0).srcId to currSrcId Author: lianhuiwang Closes #3138 from lianhuiwang/SPARK-4249 and squashes the following commits: 3f4e503 [lianhuiwang] fix a problem of EdgePartitionBuilder in Graphx --- .../org/apache/spark/graphx/impl/EdgePartitionBuilder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 4520beb991515..2b6137be25547 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -45,8 +45,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index if (edgeArray.length > 0) { - index.update(srcIds(0), 0) - var currSrcId: VertexId = srcIds(0) + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { srcIds(i) = edgeArray(i).srcId From 470881b24a503c9edcaed159c29bafa446ab0e9a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 6 Nov 2014 15:31:07 -0800 Subject: [PATCH 22/68] [HOT FIX] Make distribution fails This was added by me in https://github.com/apache/spark/commit/61a5cced049a8056292ba94f23fa7bd040f50685. The real fix will be added in [SPARK-4281](https://issues.apache.org/jira/browse/SPARK-4281). Author: Andrew Or Closes #3145 from andrewor14/fix-make-distribution and squashes the following commits: c78be61 [Andrew Or] Hot fix make distribution --- make-distribution.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index fac7f7e284be4..0bc839e1dbe4d 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -181,9 +181,6 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-yarn*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-shuffle*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-common*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" From 96136f222abd4f3abd10cb78a4ebecdb21f3bde7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 6 Nov 2014 17:18:49 -0800 Subject: [PATCH 23/68] [SPARK-3797] Minor addendum to Yarn shuffle service I did not realize there was a `network.util.JavaUtils` when I wrote this code. This PR moves the `ByteBuffer` string conversion to the appropriate place. I tested the changes on a stable yarn cluster. Author: Andrew Or Closes #3144 from andrewor14/yarn-shuffle-util and squashes the following commits: b6c08bf [Andrew Or] Remove unused import 94e205c [Andrew Or] Use netty Unpooled 85202a5 [Andrew Or] Use guava Charsets 057135b [Andrew Or] Reword comment adf186d [Andrew Or] Move byte buffer String conversion logic to JavaUtils --- .../apache/spark/network/util/JavaUtils.java | 20 ++++++++++++++++ .../network/sasl/ShuffleSecretManager.java | 24 ++----------------- .../spark/deploy/yarn/ExecutorRunnable.scala | 5 ++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 5 ++-- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 40b71b0c87a47..2856d1c8c9337 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,6 +17,8 @@ package org.apache.spark.network.util; +import java.nio.ByteBuffer; + import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Closeable; @@ -25,6 +27,8 @@ import java.io.ObjectOutputStream; import com.google.common.io.Closeables; +import com.google.common.base.Charsets; +import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,4 +77,20 @@ public static int nonNegativeHash(Object obj) { int hash = obj.hashCode(); return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; } + + /** + * Convert the given string to a byte buffer. The resulting buffer can be + * converted back to the same string through {@link #bytesToString(ByteBuffer)}. + */ + public static ByteBuffer stringToBytes(String s) { + return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be + * converted back to the same byte buffer through {@link #stringToBytes(String)}. + */ + public static String bytesToString(ByteBuffer b) { + return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index e66c4af0f1ebd..351c7930a900f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -19,13 +19,13 @@ import java.lang.Override; import java.nio.ByteBuffer; -import java.nio.charset.Charset; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.JavaUtils; /** * A class that manages shuffle secret used by the external shuffle service. @@ -34,30 +34,10 @@ public class ShuffleSecretManager implements SecretKeyHolder { private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); private final ConcurrentHashMap shuffleSecretMap; - private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); - // Spark user used for authenticating SASL connections // Note that this must match the value in org.apache.spark.SecurityManager private static final String SPARK_SASL_USER = "sparkSaslUser"; - /** - * Convert the given string to a byte buffer. The resulting buffer can be converted back to - * the same string through {@link #bytesToString(ByteBuffer)}. This is used if the external - * shuffle service represents shuffle secrets as bytes buffers instead of strings. - */ - public static ByteBuffer stringToBytes(String s) { - return ByteBuffer.wrap(s.getBytes(UTF8_CHARSET)); - } - - /** - * Convert the given byte buffer to a string. The resulting string can be converted back to - * the same byte buffer through {@link #stringToBytes(String)}. This is used if the external - * shuffle service represents shuffle secrets as bytes buffers instead of strings. - */ - public static String bytesToString(ByteBuffer b) { - return new String(b.array(), UTF8_CHARSET); - } - public ShuffleSecretManager() { shuffleSecretMap = new ConcurrentHashMap(); } @@ -80,7 +60,7 @@ public void registerApp(String appId, String shuffleSecret) { * Register an application with its secret specified as a byte buffer. */ public void registerApp(String appId, ByteBuffer shuffleSecret) { - registerApp(appId, bytesToString(shuffleSecret)); + registerApp(appId, JavaUtils.bytesToString(shuffleSecret)); } /** diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 5f47c79cabaee..7023a1170654f 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,7 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} -import org.apache.spark.network.sasl.ShuffleSecretManager +import org.apache.spark.network.util.JavaUtils @deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( @@ -98,7 +98,8 @@ class ExecutorRunnable( val secretString = securityMgr.getSecretKey() val secretBytes = if (secretString != null) { - ShuffleSecretManager.stringToBytes(secretString) + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) } else { // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 18f48b4b6caf6..fdd3c2300fa78 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,7 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} -import org.apache.spark.network.sasl.ShuffleSecretManager +import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( @@ -97,7 +97,8 @@ class ExecutorRunnable( val secretString = securityMgr.getSecretKey() val secretBytes = if (secretString != null) { - ShuffleSecretManager.stringToBytes(secretString) + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) } else { // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) From 6e9ef10fd7446a11f37446c961916ba2a8e02cb8 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 17:20:46 -0800 Subject: [PATCH 24/68] [SPARK-4277] Support external shuffle service on Standalone Worker Author: Aaron Davidson Closes #3142 from aarondav/worker and squashes the following commits: 3780bd7 [Aaron Davidson] Address comments 2dcdfc1 [Aaron Davidson] Add private[worker] 47f49d3 [Aaron Davidson] NettyBlockTransferService shouldn't care about app ids (it's only b/t executors) 258417c [Aaron Davidson] [SPARK-4277] Support external shuffle service on executor --- .../org/apache/spark/SecurityManager.scala | 14 +--- .../StandaloneWorkerShuffleService.scala | 66 +++++++++++++++++++ .../apache/spark/deploy/worker/Worker.scala | 8 ++- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../NettyBlockTransferSecuritySuite.scala | 12 ---- .../spark/network/sasl/SaslMessage.java | 3 +- 6 files changed, 79 insertions(+), 26 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index dee935ffad51f..dbff9d12b5ad7 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -343,15 +343,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with */ def getSecretKey(): String = secretKey - override def getSaslUser(appId: String): String = { - val myAppId = sparkConf.getAppId - require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") - getSaslUser() - } - - override def getSecretKey(appId: String): String = { - val myAppId = sparkConf.getAppId - require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") - getSecretKey() - } + // Default SecurityManager only has a single secret key, so ignore appId. + override def getSaslUser(appId: String): String = getSaslUser() + override def getSecretKey(appId: String): String = getSecretKey() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala new file mode 100644 index 0000000000000..88118e2837741 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.SaslRpcHandler +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler + +/** + * Provides a server from which Executors can read shuffle files (rather than reading directly from + * each other), to provide uninterrupted access to the files in the face of executors being turned + * off or killed. + * + * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. + */ +private[worker] +class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) + extends Logging { + + private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) + private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val useSasl: Boolean = securityManager.isAuthenticationEnabled() + + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf) + private val blockHandler = new ExternalShuffleBlockHandler() + private val transportContext: TransportContext = { + val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler + new TransportContext(transportConf, handler) + } + + private var server: TransportServer = _ + + /** Starts the external shuffle service if the user has configured us to. */ + def startIfEnabled() { + if (enabled) { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + } + + def stop() { + if (enabled && server != null) { + server.close() + server = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f1f66d0903f1c..ca262de832e25 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -111,6 +111,9 @@ private[spark] class Worker( val drivers = new HashMap[String, DriverRunner] val finishedDrivers = new HashMap[String, DriverRunner] + // The shuffle service is not actually started unless configured. + val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host @@ -154,6 +157,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -419,6 +423,7 @@ private[spark] class Worker( registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) + shuffleService.stop() webUi.stop() metricsSystem.stop() } @@ -441,7 +446,8 @@ private[spark] object Worker extends Logging { cores: Int, memory: Int, masterUrls: Array[String], - workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workDir: String, + workerNumber: Option[Int] = None): (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 1e579187e4193..6b1f57a069431 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -92,7 +92,7 @@ final class ShuffleBlockFetcherIterator( * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. */ - private[this] var currentResult: FetchResult = null + @volatile private[this] var currentResult: FetchResult = null /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index bed0ed9d713dd..9162ec9801663 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -89,18 +89,6 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh } } - test("security mismatch app ids") { - val conf0 = new SparkConf() - .set("spark.authenticate", "true") - .set("spark.authenticate.secret", "good") - .set("spark.app.id", "app-id") - val conf1 = conf0.clone.set("spark.app.id", "other-id") - testConnection(conf0, conf1) match { - case Success(_) => fail("Should have failed") - case Failure(t) => t.getMessage should include ("SASL appId app-id did not match") - } - } - /** * Creates two servers with different configurations and sees if they can talk. * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 5b77e18c26bf4..599cc6428c90e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -58,7 +58,8 @@ public void encode(ByteBuf buf) { public static SaslMessage decode(ByteBuf buf) { if (buf.readByte() != TAG_BYTE) { - throw new IllegalStateException("Expected SaslMessage, received something else"); + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); } int idLength = buf.readInt(); From f165b2bbf5d4acf34d826fa55b900f5bbc295654 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 18:39:14 -0800 Subject: [PATCH 25/68] [SPARK-4188] [Core] Perform network-level retry of shuffle file fetches This adds a RetryingBlockFetcher to the NettyBlockTransferService which is wrapped around our typical OneForOneBlockFetcher, adding retry logic in the event of an IOException. This sort of retry allows us to avoid marking an entire executor as failed due to garbage collection or high network load. TODO: - [x] unit tests - [x] put in ExternalShuffleClient too Author: Aaron Davidson Closes #3101 from aarondav/retry and squashes the following commits: 72a2a32 [Aaron Davidson] Add that we should remove the condition around the retry thingy c7fd107 [Aaron Davidson] Fix unit tests e80e4c2 [Aaron Davidson] Address initial comments 6f594cd [Aaron Davidson] Fix unit test 05ff43c [Aaron Davidson] Add to external shuffle client and add unit test 66e5a24 [Aaron Davidson] [SPARK-4238] [Core] Perform network-level retry of shuffle file fetches --- .../netty/NettyBlockTransferService.scala | 21 +- .../spark/network/client/TransportClient.java | 16 +- .../client/TransportClientFactory.java | 13 +- .../client/TransportResponseHandler.java | 3 +- .../network/protocol/MessageEncoder.java | 2 +- .../spark/network/server/TransportServer.java | 8 +- .../apache/spark/network/util/NettyUtils.java | 14 +- .../spark/network/util/TransportConf.java | 17 + .../network/TransportClientFactorySuite.java | 7 +- .../shuffle/ExternalShuffleClient.java | 31 +- .../shuffle/OneForOneBlockFetcher.java | 9 +- .../network/shuffle/RetryingBlockFetcher.java | 234 +++++++++++++ .../network/sasl/SaslIntegrationSuite.java | 4 +- .../ExternalShuffleIntegrationSuite.java | 18 +- .../shuffle/ExternalShuffleSecuritySuite.java | 6 +- .../shuffle/RetryingBlockFetcherSuite.java | 310 ++++++++++++++++++ 16 files changed, 668 insertions(+), 45 deletions(-) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 0d1fc81d2a16f..b937ea825f49e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -27,7 +27,7 @@ import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCal import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage listener: BlockFetchingListener): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { - val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, blockIds.toArray, listener) - .start(OpenBlocks(blockIds.map(BlockId.apply))) + val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { + override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { + val client = clientFactory.createClient(host, port) + new OneForOneBlockFetcher(client, blockIds.toArray, listener) + .start(OpenBlocks(blockIds.map(BlockId.apply))) + } + } + + val maxRetries = transportConf.maxIORetries() + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start() + } else { + blockFetchStarter.createAndStart(blockIds, listener) + } } catch { case e: Exception => logError("Exception while beginning fetchBlocks", e) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index a08cee02dd576..4e944114e8176 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -18,7 +18,9 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import com.google.common.base.Objects; @@ -116,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); - callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -147,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(requestId); - callback.onFailure(new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -175,6 +185,8 @@ public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + throw Throwables.propagate(e.getCause()); } catch (Exception e) { throw Throwables.propagate(e); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 1723fed307257..397d3a8455c86 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -18,12 +18,12 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.List; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import com.google.common.base.Preconditions; @@ -44,7 +44,6 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -93,15 +92,17 @@ public TransportClientFactory( * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) { + public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); TransportClient cachedClient = connectionPool.get(address); if (cachedClient != null) { if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); return cachedClient; } else { + logger.info("Found inactive connection to {}, closing it.", address); connectionPool.remove(address, cachedClient); // Remove inactive clients. } } @@ -133,10 +134,10 @@ public void initChannel(SocketChannel ch) { long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { - throw new RuntimeException( + throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { - throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); + throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); } TransportClient client = clientRef.get(); @@ -198,7 +199,7 @@ public void close() { */ private PooledByteBufAllocator createPooledByteBufAllocator() { return new PooledByteBufAllocator( - PlatformDependent.directBufferPreferred(), + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(), getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), getPrivateStaticField("DEFAULT_PAGE_SIZE"), diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index d8965590b34da..2044afb0d85db 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.client; +import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -94,7 +95,7 @@ public void channelUnregistered() { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); - failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); + failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 4cb8becc3ed22..91d1e8a538a77 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { // All messages have the frame length, message type, and message itself. int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); long frameLength = headerLength + bodyLength; - ByteBuf header = ctx.alloc().buffer(headerLength); + ByteBuf header = ctx.alloc().heapBuffer(headerLength); header.writeLong(frameLength); msgType.encode(header); in.encode(header); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 70da48ca8ee79..579676c2c3564 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -28,6 +28,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -71,11 +72,14 @@ private void init(int portToBind) { NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); EventLoopGroup workerGroup = bossGroup; + PooledByteBufAllocator allocator = new PooledByteBufAllocator( + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred()); + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + .option(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.ALLOCATOR, allocator); if (conf.backLog() > 0) { bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index b1872341198e0..2a7664fe89388 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -37,13 +37,17 @@ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. */ public class NettyUtils { - /** Creates a Netty EventLoopGroup based on the IOMode. */ - public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { - - ThreadFactory threadFactory = new ThreadFactoryBuilder() + /** Creates a new ThreadFactory which prefixes each thread with the given name. */ + public static ThreadFactory createThreadFactory(String threadPoolPrefix) { + return new ThreadFactoryBuilder() .setDaemon(true) - .setNameFormat(threadPrefix + "-%d") + .setNameFormat(threadPoolPrefix + "-%d") .build(); + } + + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + ThreadFactory threadFactory = createThreadFactory(threadPrefix); switch (mode) { case NIO: diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 823790dd3c66f..787a8f0031af1 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -30,6 +30,11 @@ public TransportConf(ConfigProvider conf) { /** IO mode: nio or epoll */ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + /** If true, we will prefer allocating off-heap byte buffers within Netty. */ + public boolean preferDirectBufs() { + return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + } + /** Connect timeout in secs. Default 120 secs. */ public int connectionTimeoutMs() { return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; @@ -58,4 +63,16 @@ public int connectionTimeoutMs() { /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } + + /** + * Max number of times we will try IO exceptions (such as connection timeouts) per request. + * If set to 0, we will not do any retries. + */ + public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + + /** + * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. + * Only relevant if maxIORetries > 0. + */ + public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 5a10fdb3842ef..822bef1d81b2a 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.util.concurrent.TimeoutException; import org.junit.After; @@ -57,7 +58,7 @@ public void tearDown() { } @Test - public void createAndReuseBlockClients() throws TimeoutException { + public void createAndReuseBlockClients() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); @@ -70,7 +71,7 @@ public void createAndReuseBlockClients() throws TimeoutException { } @Test - public void neverReturnInactiveClients() throws Exception { + public void neverReturnInactiveClients() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); c1.close(); @@ -88,7 +89,7 @@ public void neverReturnInactiveClients() throws Exception { } @Test - public void closeBlockClientsWithFactory() throws TimeoutException { + public void closeBlockClientsWithFactory() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 3aa95d00f6b20..27884b82c8cb9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.IOException; import java.util.List; import com.google.common.collect.Lists; @@ -76,17 +77,33 @@ public void init(String appId) { @Override public void fetchBlocks( - String host, - int port, - String execId, + final String host, + final int port, + final String execId, String[] blockIds, BlockFetchingListener listener) { assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { - TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, blockIds, listener) - .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = + new RetryingBlockFetcher.BlockFetchStarter() { + @Override + public void createAndStart(String[] blockIds, BlockFetchingListener listener) + throws IOException { + TransportClient client = clientFactory.createClient(host, port); + new OneForOneBlockFetcher(client, blockIds, listener) + .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + } + }; + + int maxRetries = conf.maxIORetries(); + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start(); + } else { + blockFetchStarter.createAndStart(blockIds, listener); + } } catch (Exception e) { logger.error("Exception while beginning fetchBlocks", e); for (String blockId : blockIds) { @@ -108,7 +125,7 @@ public void registerWithShuffleServer( String host, int port, String execId, - ExecutorShuffleInfo executorInfo) { + ExecutorShuffleInfo executorInfo) throws IOException { assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 39b6f30f92baf..9e77a1f68c4b0 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -51,9 +51,6 @@ public OneForOneBlockFetcher( TransportClient client, String[] blockIds, BlockFetchingListener listener) { - if (blockIds.length == 0) { - throw new IllegalArgumentException("Zero-sized blockIds array"); - } this.client = client; this.blockIds = blockIds; this.listener = listener; @@ -82,6 +79,10 @@ public void onFailure(int chunkIndex, Throwable e) { * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. */ public void start(Object openBlocksMessage) { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { @@ -95,7 +96,7 @@ public void onSuccess(byte[] response) { client.fetchChunk(streamHandle.streamId, i, chunkCallback); } } catch (Exception e) { - logger.error("Failed while starting block fetches", e); + logger.error("Failed while starting block fetches after success", e); failRemainingBlocks(blockIds, e); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java new file mode 100644 index 0000000000000..f8a1a266863bb --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.Uninterruptibles; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to + * IOExceptions, which we hope are due to transient network conditions. + * + * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In + * particular, the listener will be invoked exactly once per blockId, with a success or failure. + */ +public class RetryingBlockFetcher { + + /** + * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any + * remaining blocks. + */ + public static interface BlockFetchStarter { + /** + * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous + * bootstrapping followed by fully asynchronous block fetching. + * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this + * method must throw an exception. + * + * This method should always attempt to get a new TransportClient from the + * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection + * issues. + */ + void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; + } + + /** Shared executor service used for waiting and retrying. */ + private static final ExecutorService executorService = Executors.newCachedThreadPool( + NettyUtils.createThreadFactory("Block Fetch Retry")); + + private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class); + + /** Used to initiate new Block Fetches on our remaining blocks. */ + private final BlockFetchStarter fetchStarter; + + /** Parent listener which we delegate all successful or permanently failed block fetches to. */ + private final BlockFetchingListener listener; + + /** Max number of times we are allowed to retry. */ + private final int maxRetries; + + /** Milliseconds to wait before each retry. */ + private final int retryWaitTime; + + // NOTE: + // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated + // while inside a synchronized block. + /** Number of times we've attempted to retry so far. */ + private int retryCount = 0; + + /** + * Set of all block ids which have not been fetched successfully or with a non-IO Exception. + * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet, + * input ordering is preserved, so we always request blocks in the same order the user provided. + */ + private final LinkedHashSet outstandingBlocksIds; + + /** + * The BlockFetchingListener that is active with our current BlockFetcher. + * When we start a retry, we immediately replace this with a new Listener, which causes all any + * old Listeners to ignore all further responses. + */ + private RetryingBlockFetchListener currentListener; + + public RetryingBlockFetcher( + TransportConf conf, + BlockFetchStarter fetchStarter, + String[] blockIds, + BlockFetchingListener listener) { + this.fetchStarter = fetchStarter; + this.listener = listener; + this.maxRetries = conf.maxIORetries(); + this.retryWaitTime = conf.ioRetryWaitTime(); + this.outstandingBlocksIds = Sets.newLinkedHashSet(); + Collections.addAll(outstandingBlocksIds, blockIds); + this.currentListener = new RetryingBlockFetchListener(); + } + + /** + * Initiates the fetch of all blocks provided in the constructor, with possible retries in the + * event of transient IOExceptions. + */ + public void start() { + fetchAllOutstanding(); + } + + /** + * Fires off a request to fetch all blocks that have not been fetched successfully or permanently + * failed (i.e., by a non-IOException). + */ + private void fetchAllOutstanding() { + // Start by retrieving our shared state within a synchronized block. + String[] blockIdsToFetch; + int numRetries; + RetryingBlockFetchListener myListener; + synchronized (this) { + blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]); + numRetries = retryCount; + myListener = currentListener; + } + + // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails. + try { + fetchStarter.createAndStart(blockIdsToFetch, myListener); + } catch (Exception e) { + logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s", + blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e); + + if (shouldRetry(e)) { + initiateRetry(); + } else { + for (String bid : blockIdsToFetch) { + listener.onBlockFetchFailure(bid, e); + } + } + } + } + + /** + * Lightweight method which initiates a retry in a different thread. The retry will involve + * calling fetchAllOutstanding() after a configured wait time. + */ + private synchronized void initiateRetry() { + retryCount += 1; + currentListener = new RetryingBlockFetchListener(); + + logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms", + retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime); + + executorService.submit(new Runnable() { + @Override + public void run() { + Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); + fetchAllOutstanding(); + } + }); + } + + /** + * Returns true if we should retry due a block fetch failure. We will retry if and only if + * the exception was an IOException and we haven't retried 'maxRetries' times already. + */ + private synchronized boolean shouldRetry(Throwable e) { + boolean isIOException = e instanceof IOException + || (e.getCause() != null && e.getCause() instanceof IOException); + boolean hasRemainingRetries = retryCount < maxRetries; + return isIOException && hasRemainingRetries; + } + + /** + * Our RetryListener intercepts block fetch responses and forwards them to our parent listener. + * Note that in the event of a retry, we will immediately replace the 'currentListener' field, + * indicating that any responses from non-current Listeners should be ignored. + */ + private class RetryingBlockFetchListener implements BlockFetchingListener { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + // We will only forward this success message to our parent listener if this block request is + // outstanding and we are still the active listener. + boolean shouldForwardSuccess = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + outstandingBlocksIds.remove(blockId); + shouldForwardSuccess = true; + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardSuccess) { + listener.onBlockFetchSuccess(blockId, data); + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + // We will only forward this failure to our parent listener if this block request is + // outstanding, we are still the active listener, AND we cannot retry the fetch. + boolean shouldForwardFailure = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + if (shouldRetry(exception)) { + initiateRetry(); + } else { + logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)", + blockId, retryCount), exception); + outstandingBlocksIds.remove(blockId); + shouldForwardFailure = true; + } + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardFailure) { + listener.onBlockFetchFailure(blockId, exception); + } + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 84781207861ed..d25283e46ef96 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -93,7 +93,7 @@ public void afterEach() { } @Test - public void testGoodClient() { + public void testGoodClient() throws IOException { clientFactory = context.createClientFactory( Lists.newArrayList( new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); @@ -119,7 +119,7 @@ public void testBadClient() { } @Test - public void testNoSaslClient() { + public void testNoSaslClient() throws IOException { clientFactory = context.createClientFactory( Lists.newArrayList()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 71e017b9e4e74..06294fef19621 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -259,14 +259,20 @@ public void testFetchUnregisteredExecutor() throws Exception { @Test public void testFetchNoServer() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + System.setProperty("spark.shuffle.io.maxRetries", "0"); + try { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } finally { + System.clearProperty("spark.shuffle.io.maxRetries"); + } } - private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { + private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + throws IOException { ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 4c18fcdfbcd88..848c88f743d50 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -17,6 +17,8 @@ package org.apache.spark.network.shuffle; +import java.io.IOException; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -54,7 +56,7 @@ public void afterEach() { } @Test - public void testValid() { + public void testValid() throws IOException { validate("my-app-id", "secret"); } @@ -77,7 +79,7 @@ public void testBadSecret() { } /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey) { + private void validate(String appId, String secretKey) throws IOException { ExternalShuffleClient client = new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); client.init(appId); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java new file mode 100644 index 0000000000000..0191fe529e1be --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedHashSet; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; + +/** + * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to + * fetch the lost blocks. + */ +public class RetryingBlockFetcherSuite { + + ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); + ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + + @Before + public void beforeEach() { + System.setProperty("spark.shuffle.io.maxRetries", "2"); + System.setProperty("spark.shuffle.io.retryWaitMs", "0"); + } + + @After + public void afterEach() { + System.clearProperty("spark.shuffle.io.maxRetries"); + System.clearProperty("spark.shuffle.io.retryWaitMs"); + } + + @Test + public void testNoFailures() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // Immediately return both blocks successfully. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchSuccess("b0", block0); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testUnrecoverableFailure() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0 throws a non-IOException error, so it will be failed without retry. + ImmutableMap.builder() + .put("b0", new RuntimeException("Ouch!")) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnFirst() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b0 fails, we will retry both. + ImmutableMap.builder() + .put("b0", new IOException("Connection failed or something")) + .put("b1", block1) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnSecond() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b1 fails, we will not retry b0. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException("Connection failed or something")) + .build(), + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testTwoIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 returns successfully within 2 retries. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testThreeIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 errors again, but this was the last retry + ImmutableMap.builder() + .put("b1", new IOException()) + .build(), + // This is not reached -- b1 has failed. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verifyNoMoreInteractions(listener); + } + + @Test + public void testRetryAndUnrecoverable() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, subsequent messages will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new RuntimeException()) + .put("b2", block2) + .build(), + // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new RuntimeException()) + .put("b2", new IOException()) + .build(), + // b2 succeeds in its last retry. + ImmutableMap.builder() + .put("b2", block2) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); + verifyNoMoreInteractions(listener); + } + + /** + * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. + * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction + * means "respond to the next block fetch request with these Successful buffers and these Failure + * exceptions". We verify that the expected block ids are exactly the ones requested. + * + * If multiple interactions are supplied, they will be used in order. This is useful for encoding + * retries -- the first interaction may include an IOException, which causes a retry of some + * subset of the original blocks in a second interaction. + */ + @SuppressWarnings("unchecked") + private void performInteractions(final Map[] interactions, BlockFetchingListener listener) + throws IOException { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); + + Stubber stub = null; + + // Contains all blockIds that are referenced across all interactions. + final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + + for (final Map interaction : interactions) { + blockIds.addAll(interaction.keySet()); + + Answer answer = new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. + BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); + } + } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; + } + } + }; + + // This is either the first stub, or should be chained behind the prior ones. + if (stub == null) { + stub = doAnswer(answer); + } else { + stub.doAnswer(answer); + } + } + + assert stub != null; + stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); + new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); + } +} From 48a19a6dba896f7d0b637f84e114b7efbb814e51 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 19:54:32 -0800 Subject: [PATCH 26/68] [SPARK-4236] Cleanup removed applications' files in shuffle service This relies on a hook from whoever is hosting the shuffle service to invoke removeApplication() when the application is completed. Once invoked, we will clean up all the executors' shuffle directories we know about. Author: Aaron Davidson Closes #3126 from aarondav/cleanup and squashes the following commits: 33a64a9 [Aaron Davidson] Missing brace e6e428f [Aaron Davidson] Address comments 16a0d27 [Aaron Davidson] Cleanup e4df3e7 [Aaron Davidson] [SPARK-4236] Cleanup removed applications' files in shuffle service --- .../scala/org/apache/spark/util/Utils.scala | 1 + .../spark/ExternalShuffleServiceSuite.scala | 5 +- .../apache/spark/network/util/JavaUtils.java | 59 ++++++++ .../shuffle/ExternalShuffleBlockHandler.java | 10 +- .../shuffle/ExternalShuffleBlockManager.java | 118 +++++++++++++-- .../shuffle/ExternalShuffleCleanupSuite.java | 142 ++++++++++++++++++ .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/TestShuffleDataContext.java | 4 +- 8 files changed, 319 insertions(+), 22 deletions(-) create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7caf6bcf94ef3..2cbd38d72caa1 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -755,6 +755,7 @@ private[spark] object Utils extends Logging { /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. */ def deleteRecursively(file: File) { if (file != null) { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 792b9cd8b6ff2..6608ed1e57b38 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -63,8 +63,9 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { rdd.count() rdd.count() - // Invalidate the registered executors, disallowing access to their shuffle blocks. - rpcHandler.clearRegisteredExecutors() + // Invalidate the registered executors, disallowing access to their shuffle blocks (without + // deleting the actual shuffle files, so we could access them without the shuffle service). + rpcHandler.applicationRemoved(sc.conf.getAppId, false /* cleanupLocalDirs */) // Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry" // being set. diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 2856d1c8c9337..75c4a3981a240 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -22,16 +22,22 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Closeable; +import java.io.File; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import com.google.common.base.Preconditions; import com.google.common.io.Closeables; import com.google.common.base.Charsets; import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * General utilities available in the network package. Many of these are sourced from Spark's + * own Utils, just accessible within this package. + */ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); @@ -93,4 +99,57 @@ public static ByteBuffer stringToBytes(String s) { public static String bytesToString(ByteBuffer b) { return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); } + + /* + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. + */ + public static void deleteRecursively(File file) throws IOException { + if (file == null) { return; } + + if (file.isDirectory() && !isSymlink(file)) { + IOException savedIOException = null; + for (File child : listFilesSafely(file)) { + try { + deleteRecursively(child); + } catch (IOException e) { + // In case of multiple exceptions, only last one will be thrown + savedIOException = e; + } + } + if (savedIOException != null) { + throw savedIOException; + } + } + + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } + } + + private static File[] listFilesSafely(File file) throws IOException { + if (file.exists()) { + File[] files = file.listFiles(); + if (files == null) { + throw new IOException("Failed to list files for dir: " + file); + } + return files; + } else { + return new File[0]; + } + } + + private static boolean isSymlink(File file) throws IOException { + Preconditions.checkNotNull(file); + File fileInCanonicalDir = null; + if (file.getParent() == null) { + fileInCanonicalDir = file; + } else { + fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName()); + } + return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index cd3fea85b19a4..75ebf8c7b0604 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -94,9 +94,11 @@ public StreamManager getStreamManager() { return streamManager; } - /** For testing, clears all executors registered with "RegisterExecutor". */ - @VisibleForTesting - public void clearRegisteredExecutors() { - blockManager.clearRegisteredExecutors(); + /** + * Removes an application (once it has been terminated), and optionally will clean up any + * local directories associated with the executors of that application in a separate thread. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + blockManager.applicationRemoved(appId, cleanupLocalDirs); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index 6589889fe1be7..98fcfb82aa5d1 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -21,9 +21,15 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,13 +49,22 @@ public class ExternalShuffleBlockManager { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); - // Map from "appId-execId" to the executor's configuration. - private final ConcurrentHashMap executors = - new ConcurrentHashMap(); + // Map containing all registered executors' metadata. + private final ConcurrentMap executors; - // Returns an id suitable for a single executor within a single application. - private String getAppExecId(String appId, String execId) { - return appId + "-" + execId; + // Single-threaded Java executor used to perform expensive recursive directory deletion. + private final Executor directoryCleaner; + + public ExternalShuffleBlockManager() { + // TODO: Give this thread a name. + this(Executors.newSingleThreadExecutor()); + } + + // Allows tests to have more control over when directories are cleaned up. + @VisibleForTesting + ExternalShuffleBlockManager(Executor directoryCleaner) { + this.executors = Maps.newConcurrentMap(); + this.directoryCleaner = directoryCleaner; } /** Registers a new Executor with all the configuration we need to find its shuffle files. */ @@ -57,7 +72,7 @@ public void registerExecutor( String appId, String execId, ExecutorShuffleInfo executorInfo) { - String fullId = getAppExecId(appId, execId); + AppExecId fullId = new AppExecId(appId, execId); logger.info("Registered executor {} with {}", fullId, executorInfo); executors.put(fullId, executorInfo); } @@ -78,7 +93,7 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { int mapId = Integer.parseInt(blockIdParts[2]); int reduceId = Integer.parseInt(blockIdParts[3]); - ExecutorShuffleInfo executor = executors.get(getAppExecId(appId, execId)); + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); @@ -94,6 +109,56 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { } } + /** + * Removes our metadata of all executors registered for the given application, and optionally + * also deletes the local directories associated with the executors of that application in a + * separate thread. + * + * It is not valid to call registerExecutor() for an executor with this appId after invoking + * this method. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); + Iterator> it = executors.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + AppExecId fullId = entry.getKey(); + final ExecutorShuffleInfo executor = entry.getValue(); + + // Only touch executors associated with the appId that was removed. + if (appId.equals(fullId.appId)) { + it.remove(); + + if (cleanupLocalDirs) { + logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(new Runnable() { + @Override + public void run() { + deleteExecutorDirs(executor.localDirs); + } + }); + } + } + } + } + + /** + * Synchronously deletes each directory one at a time. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteExecutorDirs(String[] dirs) { + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir)); + logger.debug("Successfully cleaned up directory: " + localDir); + } catch (Exception e) { + logger.error("Failed to delete directory: " + localDir, e); + } + } + } + /** * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockManager. @@ -146,9 +211,36 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) return new File(new File(localDir, String.format("%02x", subDirId)), filename); } - /** For testing, clears all registered executors. */ - @VisibleForTesting - void clearRegisteredExecutors() { - executors.clear(); + /** Simply encodes an executor's full ID, which is appId + execId. */ + private static class AppExecId { + final String appId; + final String execId; + + private AppExecId(String appId, String execId) { + this.appId = appId; + this.execId = execId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppExecId appExecId = (AppExecId) o; + return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .toString(); + } } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java new file mode 100644 index 0000000000000..c8ece3bc53ac3 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ExternalShuffleCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + + @Test + public void noCleanupAndCleanup() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", false /* cleanup */); + + assertStillThere(dataContext); + + manager.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true /* cleanup */); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutor() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. + Executor noThreadExecutor = new Executor() { + @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } + }; + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + + dataContext.cleanup(); + assertCleanedUp(dataContext); + } + + @Test + public void cleanupMultipleExecutors() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRemovedApp() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + + manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); + + manager.applicationRemoved("app-nonexistent", true); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-0", true); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + private void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); + } + } + + private TestShuffleDataContext createSomeData() throws IOException { + Random rand = new Random(123); + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + + dataContext.create(); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), + new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, + new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + return dataContext; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 06294fef19621..3bea5b0f253c6 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -105,7 +105,7 @@ public static void afterAll() { @After public void afterEach() { - handler.clearRegisteredExecutors(); + handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); } class FetchResult { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 442b756467442..337b5c7bdb5da 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -30,8 +30,8 @@ * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. */ public class TestShuffleDataContext { - private final String[] localDirs; - private final int subDirsPerLocalDir; + public final String[] localDirs; + public final int subDirsPerLocalDir; public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { this.localDirs = new String[numLocalDirs]; From 3abdb1b24aa48f21e7eed1232c01d3933873688c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Nov 2014 21:52:12 -0800 Subject: [PATCH 27/68] [SPARK-4204][Core][WebUI] Change Utils.exceptionString to contain the inner exceptions and make the error information in Web UI more friendly This PR fixed `Utils.exceptionString` to output the full exception information. However, the stack trace may become very huge, so I also updated the Web UI to collapse the error information by default (display the first line and clicking `+detail` will display the full info). Here are the screenshots: Stages: ![stages](https://cloud.githubusercontent.com/assets/1000778/4882441/66d8cc68-6356-11e4-8346-6318677d9470.png) Details for one stage: ![stage](https://cloud.githubusercontent.com/assets/1000778/4882513/1311043c-6357-11e4-8804-ca14240a9145.png) The full information in the gray text field is: ```Java org.apache.spark.shuffle.FetchFailedException: Connection reset by peer at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$.org$apache$spark$shuffle$hash$BlockStoreShuffleFetcher$$unpackBlock$1(BlockStoreShuffleFetcher.scala:67) at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$$anonfun$3.apply(BlockStoreShuffleFetcher.scala:83) at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$$anonfun$3.apply(BlockStoreShuffleFetcher.scala:83) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:30) at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:39) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327) at org.apache.spark.util.collection.ExternalAppendOnlyMap.insertAll(ExternalAppendOnlyMap.scala:129) at org.apache.spark.rdd.CoGroupedRDD$$anonfun$compute$5.apply(CoGroupedRDD.scala:160) at org.apache.spark.rdd.CoGroupedRDD$$anonfun$compute$5.apply(CoGroupedRDD.scala:159) at scala.collection.TraversableLike$WithFilter$$anonfun$foreach$1.apply(TraversableLike.scala:772) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at scala.collection.TraversableLike$WithFilter.foreach(TraversableLike.scala:771) at org.apache.spark.rdd.CoGroupedRDD.compute(CoGroupedRDD.scala:159) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at org.apache.spark.rdd.MappedValuesRDD.compute(MappedValuesRDD.scala:31) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at org.apache.spark.rdd.FlatMappedValuesRDD.compute(FlatMappedValuesRDD.scala:31) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61) at org.apache.spark.scheduler.Task.run(Task.scala:56) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:189) at java.util.concurrent.ThreadPoolExecutor$Worker.runTask(ThreadPoolExecutor.java:886) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:908) at java.lang.Thread.run(Thread.java:662) Caused by: java.io.IOException: Connection reset by peer at sun.nio.ch.FileDispatcher.read0(Native Method) at sun.nio.ch.SocketDispatcher.read(SocketDispatcher.java:21) at sun.nio.ch.IOUtil.readIntoNativeBuffer(IOUtil.java:198) at sun.nio.ch.IOUtil.read(IOUtil.java:166) at sun.nio.ch.SocketChannelImpl.read(SocketChannelImpl.java:245) at io.netty.buffer.PooledUnsafeDirectByteBuf.setBytes(PooledUnsafeDirectByteBuf.java:311) at io.netty.buffer.AbstractByteBuf.writeBytes(AbstractByteBuf.java:881) at io.netty.channel.socket.nio.NioSocketChannel.doReadBytes(NioSocketChannel.java:225) at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:119) at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511) at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468) at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382) at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354) at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:116) ... 1 more ``` /cc aarondav Author: zsxwing Closes #3073 from zsxwing/SPARK-4204 and squashes the following commits: 176d1e3 [zsxwing] Add comments to explain the stack trace difference ca509d3 [zsxwing] Add fullStackTrace to the constructor of ExceptionFailure a07057b [zsxwing] Core style fix dfb0032 [zsxwing] Backward compatibility for old history server 1e50f71 [zsxwing] Update as per review and increase the max height of the stack trace details 94f2566 [zsxwing] Change Utils.exceptionString to contain the inner exceptions and make the error information in Web UI more friendly --- .../org/apache/spark/ui/static/webui.css | 14 ++++++++ .../org/apache/spark/TaskEndReason.scala | 35 ++++++++++++++++++- .../org/apache/spark/executor/Executor.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 4 +-- .../spark/shuffle/FetchFailedException.scala | 17 +++++++-- .../hash/BlockStoreShuffleFetcher.scala | 5 ++- .../org/apache/spark/ui/jobs/StagePage.scala | 32 +++++++++++++++-- .../org/apache/spark/ui/jobs/StageTable.scala | 28 +++++++++++++-- .../org/apache/spark/util/JsonProtocol.scala | 5 ++- .../scala/org/apache/spark/util/Utils.scala | 24 ++++++------- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 10 +++++- 12 files changed, 148 insertions(+), 30 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a2220e761ac98..db57712c83503 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -120,6 +120,20 @@ pre { border: none; } +.stacktrace-details { + max-height: 300px; + overflow-y: auto; + margin: 0; + transition: max-height 0.5s ease-out, padding 0.5s ease-out; +} + +.stacktrace-details.collapsed { + max-height: 0; + padding-top: 0; + padding-bottom: 0; + border: none; +} + span.expand-additional-metrics { cursor: pointer; } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index f45b463fb6f62..af5fd8e0ac00c 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -83,15 +83,48 @@ case class FetchFailed( * :: DeveloperApi :: * Task failed due to a runtime exception. This is the most common failure case and also captures * user program exceptions. + * + * `stackTrace` contains the stack trace of the exception itself. It still exists for backward + * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to + * create `ExceptionFailure` as it will handle the backward compatibility properly. + * + * `fullStackTrace` is a better representation of the stack trace because it contains the whole + * stack trace including the exception and its causes */ @DeveloperApi case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], + fullStackTrace: String, metrics: Option[TaskMetrics]) extends TaskFailedReason { - override def toErrorString: String = Utils.exceptionString(className, description, stackTrace) + + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + } + + override def toErrorString: String = + if (fullStackTrace == null) { + // fullStackTrace is added in 1.2.0 + // If fullStackTrace is null, use the old error string for backward compatibility + exceptionString(className, description, stackTrace) + } else { + fullStackTrace + } + + /** + * Return a nice string representation of the exception, including the stack trace. + * Note: It does not include the exception's causes, and is only used for backward compatibility. + */ + private def exceptionString( + className: String, + description: String, + stackTrace: Array[StackTraceElement]): String = { + val desc = if (description == null) "" else description + val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") + s"$className: $desc\n$st" + } } /** diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 96114571d6c77..caf4d76713d49 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -263,7 +263,7 @@ private[spark] class Executor( m.executorRunTime = serviceTime m.jvmGCTime = gcTime - startGCTime } - val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) + val reason = new ExceptionFailure(t, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 96114c0423a9e..22449517d100f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1063,7 +1063,7 @@ class DAGScheduler( if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some("Fetch failure: " + failureMessage)) + markStageAsFinished(failedStage, Some(failureMessage)) runningStages -= failedStage } @@ -1094,7 +1094,7 @@ class DAGScheduler( handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } - case ExceptionFailure(className, description, stackTrace, metrics) => + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 0c1b6f4defdb3..be184464e0ae9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -32,10 +32,21 @@ private[spark] class FetchFailedException( shuffleId: Int, mapId: Int, reduceId: Int, - message: String) - extends Exception(message) { + message: String, + cause: Throwable = null) + extends Exception(message, cause) { + + def this( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int, + cause: Throwable) { + this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) + } - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, + Utils.exceptionString(this)) } /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 0d5247f4176d4..e3e7434df45b0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -25,7 +25,7 @@ import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.util.CompletionIterator private[hash] object BlockStoreShuffleFetcher extends Logging { def fetch[T]( @@ -64,8 +64,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, - Utils.exceptionString(e)) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 63ed5fc4949c2..250bddbe2f262 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -22,6 +22,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.executor.TaskMetrics import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ @@ -436,13 +438,37 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {diskBytesSpilledReadable} }} - - {errorMessage.map { e =>
    {e}
    }.getOrElse("")} - + {errorMessageCell(errorMessage)} } } + private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { + val error = errorMessage.getOrElse("") + val isMultiline = error.indexOf('\n') >= 0 + // Display the first line by default + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + error.substring(0, error.indexOf('\n')) + } else { + error + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + {errorSummary}{details} + } + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { val totalExecutionTime = { if (info.gettingResultTime > 0) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 4ee7f08ab47a2..3b4866e05956d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,6 +22,8 @@ import scala.xml.Text import java.util.Date +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.util.Utils @@ -195,7 +197,29 @@ private[ui] class FailedStageTable( override protected def stageRow(s: StageInfo): Seq[Node] = { val basicColumns = super.stageRow(s) - val failureReason =
    {s.failureReason.getOrElse("")}
    - basicColumns ++ failureReason + val failureReason = s.failureReason.getOrElse("") + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + val failureReasonHtml = {failureReasonSummary}{details} + basicColumns ++ failureReasonHtml } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f7ae1f7f334de..f15d0c856663f 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -287,6 +287,7 @@ private[spark] object JsonProtocol { ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ + ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) case ExecutorLostFailure(executorId) => ("Executor ID" -> executorId) @@ -637,8 +638,10 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") + val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). + map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - new ExceptionFailure(className, description, stackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2cbd38d72caa1..a14d6125484fe 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1599,19 +1599,19 @@ private[spark] object Utils extends Logging { .orNull } - /** Return a nice string representation of the exception, including the stack trace. */ + /** + * Return a nice string representation of the exception. It will call "printStackTrace" to + * recursively generate the stack trace including the exception and its causes. + */ def exceptionString(e: Throwable): String = { - if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) - } - - /** Return a nice string representation of the exception, including the stack trace. */ - def exceptionString( - className: String, - description: String, - stackTrace: Array[StackTraceElement]): String = { - val desc = if (description == null) "" else description - val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") - s"$className: $desc\n$st" + if (e == null) { + "" + } else { + // Use e.printStackTrace here because e.getStackTrace doesn't include the cause + val stringWriter = new StringWriter() + e.printStackTrace(new PrintWriter(stringWriter)) + stringWriter.toString + } } /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 2efbae689771a..2608ad4b32e1e 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -116,7 +116,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - new ExceptionFailure("Exception", "description", null, None), + ExceptionFailure("Exception", "description", null, null, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index aec1e409db95c..39e69851e7e3c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -109,7 +109,7 @@ class JsonProtocolSuite extends FunSuite { // TaskEndReason val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, "Some exception") - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None) + val exceptionFailure = new ExceptionFailure(exception, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) @@ -127,6 +127,13 @@ class JsonProtocolSuite extends FunSuite { testBlockId(StreamBlockId(1, 2L)) } + test("ExceptionFailure backward compatibility") { + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) + .removeField({ _._1 == "Full Stack Trace" }) + assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("StageInfo backward compatibility") { val info = makeStageInfo(1, 2, 3, 4L, 5L) val newJson = JsonProtocol.stageInfoToJson(info) @@ -422,6 +429,7 @@ class JsonProtocolSuite extends FunSuite { assert(r1.className === r2.className) assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) + assert(r1.fullStackTrace === r2.fullStackTrace) assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => From d4fa04e50d299e9cad349b3781772956453a696b Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 7 Nov 2014 09:42:21 -0800 Subject: [PATCH 28/68] [SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages This PR elimiantes the network package's usage of the Java serializer and replaces it with Encodable, which is a lightweight binary protocol. Each message is preceded by a type id, which will allow us to change messages (by only adding new ones), or to change the format entirely by switching to a special id (such as -1). This protocol has the advantage over Java that we can guarantee that messages will remain compatible across compiled versions and JVMs, though it does not provide a clean way to do schema migration. In the future, it may be good to use a more heavy-weight serialization format like protobuf, thrift, or avro, but these all add several dependencies which are unnecessary at the present time. Additionally this unifies the RPC messages of NettyBlockTransferService and ExternalShuffleClient. Author: Aaron Davidson Closes #3146 from aarondav/free and squashes the following commits: ed1102a [Aaron Davidson] Remove some unused imports b8e2a49 [Aaron Davidson] Add appId to test 538f2a3 [Aaron Davidson] [SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages --- .../spark/network/BlockTransferService.scala | 4 +- .../network/netty/NettyBlockRpcServer.scala | 31 ++--- .../netty/NettyBlockTransferService.scala | 15 ++- .../network/nio/NioBlockTransferService.scala | 1 + .../apache/spark/storage/BlockManager.scala | 5 +- .../NettyBlockTransferSecuritySuite.scala | 4 +- .../network/protocol/ChunkFetchFailure.java | 12 +- .../spark/network/protocol/Encoders.java | 93 ++++++++++++++ .../spark/network/protocol/RpcFailure.java | 12 +- .../spark/network/protocol/RpcRequest.java | 9 +- .../spark/network/protocol/RpcResponse.java | 9 +- .../apache/spark/network/util/JavaUtils.java | 27 ----- .../spark/network/sasl/SaslMessage.java | 24 ++-- .../shuffle/ExternalShuffleBlockHandler.java | 21 ++-- .../shuffle/ExternalShuffleBlockManager.java | 1 + .../shuffle/ExternalShuffleClient.java | 12 +- .../shuffle/ExternalShuffleMessages.java | 106 ---------------- .../shuffle/OneForOneBlockFetcher.java | 17 ++- .../protocol/BlockTransferMessage.java | 76 ++++++++++++ .../{ => protocol}/ExecutorShuffleInfo.java | 36 +++++- .../network/shuffle/protocol/OpenBlocks.java | 87 ++++++++++++++ .../shuffle/protocol/RegisterExecutor.java | 91 ++++++++++++++ .../StreamHandle.java} | 34 ++++-- .../network/shuffle/protocol/UploadBlock.java | 113 ++++++++++++++++++ ...e.java => BlockTransferMessagesSuite.java} | 33 ++--- .../ExternalShuffleBlockHandlerSuite.java | 29 ++--- .../ExternalShuffleIntegrationSuite.java | 1 + .../shuffle/ExternalShuffleSecuritySuite.java | 1 + .../shuffle/OneForOneBlockFetcherSuite.java | 18 +-- .../shuffle/TestShuffleDataContext.java | 2 + 30 files changed, 640 insertions(+), 284 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java rename network/shuffle/src/main/java/org/apache/spark/network/shuffle/{ => protocol}/ExecutorShuffleInfo.java (68%) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java rename network/shuffle/src/main/java/org/apache/spark/network/shuffle/{ShuffleStreamHandle.java => protocol/StreamHandle.java} (65%) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java rename network/shuffle/src/test/java/org/apache/spark/network/shuffle/{ShuffleMessagesSuite.java => BlockTransferMessagesSuite.java} (55%) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 210a581db466e..dcbda5a8515dd 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -73,6 +73,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] @@ -110,9 +111,10 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo def uploadBlockSync( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Unit = { - Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) + Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 1950e7bd634ee..b089da8596e2b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -26,18 +26,10 @@ import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} -import org.apache.spark.network.shuffle.ShuffleStreamHandle +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, StorageLevel} -object NettyMessages { - /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ - case class OpenBlocks(blockIds: Seq[BlockId]) - - /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ - case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel) -} - /** * Serves requests to open blocks by simply registering one chunk per block requested. * Handles opening and uploading arbitrary BlockManager blocks. @@ -50,28 +42,29 @@ class NettyBlockRpcServer( blockManager: BlockDataManager) extends RpcHandler with Logging { - import NettyMessages._ - private val streamManager = new OneForOneStreamManager() override def receive( client: TransportClient, messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { - val ser = serializer.newInstance() - val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) + val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) logTrace(s"Received request: $message") message match { - case OpenBlocks(blockIds) => - val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) + case openBlocks: OpenBlocks => + val blocks: Seq[ManagedBuffer] = + openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(blocks.iterator) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess( - ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) - case UploadBlock(blockId, blockData, level) => - blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level) + case uploadBlock: UploadBlock => + // StorageLevel is serialized as bytes using our JavaSerializer. + val level: StorageLevel = + serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) + val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) + blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) responseContext.onSuccess(new Array[Byte](0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b937ea825f49e..f8a7f640689a2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -24,10 +24,10 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} -import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -46,6 +46,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ + private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { @@ -60,6 +61,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage transportContext = new TransportContext(transportConf, rpcHandler) clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() + appId = conf.getAppId logInfo("Server created on " + server.getPort) } @@ -74,8 +76,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, blockIds.toArray, listener) - .start(OpenBlocks(blockIds.map(BlockId.apply))) + new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() } } @@ -101,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage override def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] = { val result = Promise[Unit]() val client = clientFactory.createClient(hostname, port) + // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded + // using our binary protocol. + val levelBytes = serializer.newInstance().serialize(level).array() + // Convert or copy nio buffer into array in order to serialize it. val nioBuffer = blockData.nioByteBuffer() val array = if (nioBuffer.hasArray) { @@ -117,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage data } - val ser = serializer.newInstance() - client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(), + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, new RpcResponseCallback { override def onSuccess(response: Array[Byte]): Unit = { logTrace(s"Successfully uploaded block $blockId") diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index f56d165daba55..b2aec160635c7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -137,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa override def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e48d7772d6ee9..39434f473a9d8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -35,7 +35,8 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} -import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient} +import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager @@ -939,7 +940,7 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel) + peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel) logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 9162ec9801663..530f5d6db5a29 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -36,7 +36,9 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMat class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { test("security default off") { - testConnection(new SparkConf, new SparkConf) match { + val conf = new SparkConf() + .set("spark.app.id", "app-id") + testConnection(conf, conf) match { case Success(_) => // expected case Failure(t) => fail(t) } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 152af98ced7ce..986957c1509fd 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -38,23 +38,19 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { @Override public int encodedLength() { - return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length; + return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); - byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); - buf.writeInt(errorBytes.length); - buf.writeBytes(errorBytes); + Encoders.Strings.encode(buf, errorString); } public static ChunkFetchFailure decode(ByteBuf buf) { StreamChunkId streamChunkId = StreamChunkId.decode(buf); - int numErrorStringBytes = buf.readInt(); - byte[] errorBytes = new byte[numErrorStringBytes]; - buf.readBytes(errorBytes); - return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8)); + String errorString = Encoders.Strings.decode(buf); + return new ChunkFetchFailure(streamChunkId, errorString); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java new file mode 100644 index 0000000000000..873c694250942 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** Provides a canonical set of Encoders for simple types. */ +public class Encoders { + + /** Strings are encoded with their length followed by UTF-8 bytes. */ + public static class Strings { + public static int encodedLength(String s) { + return 4 + s.getBytes(Charsets.UTF_8).length; + } + + public static void encode(ByteBuf buf, String s) { + byte[] bytes = s.getBytes(Charsets.UTF_8); + buf.writeInt(bytes.length); + buf.writeBytes(bytes); + } + + public static String decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return new String(bytes, Charsets.UTF_8); + } + } + + /** Byte arrays are encoded with their length followed by bytes. */ + public static class ByteArrays { + public static int encodedLength(byte[] arr) { + return 4 + arr.length; + } + + public static void encode(ByteBuf buf, byte[] arr) { + buf.writeInt(arr.length); + buf.writeBytes(arr); + } + + public static byte[] decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return bytes; + } + } + + /** String arrays are encoded with the number of strings followed by per-String encoding. */ + public static class StringArrays { + public static int encodedLength(String[] strings) { + int totalLength = 4; + for (String s : strings) { + totalLength += Strings.encodedLength(s); + } + return totalLength; + } + + public static void encode(ByteBuf buf, String[] strings) { + buf.writeInt(strings.length); + for (String s : strings) { + Strings.encode(buf, s); + } + } + + public static String[] decode(ByteBuf buf) { + int numStrings = buf.readInt(); + String[] strings = new String[numStrings]; + for (int i = 0; i < strings.length; i ++) { + strings[i] = Strings.decode(buf); + } + return strings; + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index e239d4ffbd29c..ebd764eb5eb5f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -36,23 +36,19 @@ public RpcFailure(long requestId, String errorString) { @Override public int encodedLength() { - return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length; + return 8 + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); - buf.writeInt(errorBytes.length); - buf.writeBytes(errorBytes); + Encoders.Strings.encode(buf, errorString); } public static RpcFailure decode(ByteBuf buf) { long requestId = buf.readLong(); - int numErrorStringBytes = buf.readInt(); - byte[] errorBytes = new byte[numErrorStringBytes]; - buf.readBytes(errorBytes); - return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8)); + String errorString = Encoders.Strings.decode(buf); + return new RpcFailure(requestId, errorString); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 099e934ae018c..cdee0b0e0316b 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -44,21 +44,18 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + 4 + message.length; + return 8 + Encoders.ByteArrays.encodedLength(message); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - buf.writeInt(message.length); - buf.writeBytes(message); + Encoders.ByteArrays.encode(buf, message); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - int messageLen = buf.readInt(); - byte[] message = new byte[messageLen]; - buf.readBytes(message); + byte[] message = Encoders.ByteArrays.decode(buf); return new RpcRequest(requestId, message); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index ed479478325b6..0a62e09a8115c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -36,20 +36,17 @@ public RpcResponse(long requestId, byte[] response) { public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + 4 + response.length; } + public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - buf.writeInt(response.length); - buf.writeBytes(response); + Encoders.ByteArrays.encode(buf, response); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - int responseLen = buf.readInt(); - byte[] response = new byte[responseLen]; - buf.readBytes(response); + byte[] response = Encoders.ByteArrays.decode(buf); return new RpcResponse(requestId, response); } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 75c4a3981a240..009dbcf01323f 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -50,33 +50,6 @@ public static void closeQuietly(Closeable closeable) { } } - // TODO: Make this configurable, do not use Java serialization! - public static T deserialize(byte[] bytes) { - try { - ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes)); - Object out = is.readObject(); - is.close(); - return (T) out; - } catch (ClassNotFoundException e) { - throw new RuntimeException("Could not deserialize object", e); - } catch (IOException e) { - throw new RuntimeException("Could not deserialize object", e); - } - } - - // TODO: Make this configurable, do not use Java serialization! - public static byte[] serialize(Object object) { - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream os = new ObjectOutputStream(baos); - os.writeObject(object); - os.close(); - return baos.toByteArray(); - } catch (IOException e) { - throw new RuntimeException("Could not serialize object", e); - } - } - /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */ public static int nonNegativeHash(Object obj) { if (obj == null) { return 0; } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 599cc6428c90e..cad76ab7aa54e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -17,10 +17,10 @@ package org.apache.spark.network.sasl; -import com.google.common.base.Charsets; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged @@ -42,18 +42,14 @@ public SaslMessage(String appId, byte[] payload) { @Override public int encodedLength() { - // tag + appIdLength + appId + payloadLength + payload - return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); - byte[] idBytes = appId.getBytes(Charsets.UTF_8); - buf.writeInt(idBytes.length); - buf.writeBytes(idBytes); - buf.writeInt(payload.length); - buf.writeBytes(payload); + Encoders.Strings.encode(buf, appId); + Encoders.ByteArrays.encode(buf, payload); } public static SaslMessage decode(ByteBuf buf) { @@ -62,14 +58,8 @@ public static SaslMessage decode(ByteBuf buf) { + " (maybe your client does not have SASL enabled?)"); } - int idLength = buf.readInt(); - byte[] idBytes = new byte[idLength]; - buf.readBytes(idBytes); - - int payloadLength = buf.readInt(); - byte[] payload = new byte[payloadLength]; - buf.readBytes(payload); - - return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + String appId = Encoders.Strings.decode(buf); + byte[] payload = Encoders.ByteArrays.decode(buf); + return new SaslMessage(appId, payload); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 75ebf8c7b0604..a6db4b2abd6c9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -24,15 +24,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; - import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. @@ -62,12 +63,10 @@ public ExternalShuffleBlockHandler() { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - Object msgObj = JavaUtils.deserialize(message); - - logger.trace("Received message: " + msgObj); + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); - if (msgObj instanceof OpenShuffleBlocks) { - OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj; + if (msgObj instanceof OpenBlocks) { + OpenBlocks msg = (OpenBlocks) msgObj; List blocks = Lists.newArrayList(); for (String blockId : msg.blockIds) { @@ -75,8 +74,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } long streamId = streamManager.registerStream(blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(JavaUtils.serialize( - new ShuffleStreamHandle(streamId, msg.blockIds.length))); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; @@ -84,8 +82,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback.onSuccess(new byte[0]); } else { - throw new UnsupportedOperationException(String.format( - "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass())); + throw new UnsupportedOperationException("Unexpected message: " + msgObj); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index 98fcfb82aa5d1..ffb7faa3dbdca 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -35,6 +35,7 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.JavaUtils; /** diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 27884b82c8cb9..6e8018b723dc6 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -31,8 +31,8 @@ import org.apache.spark.network.sasl.SaslClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; -import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.util.TransportConf; /** @@ -91,8 +91,7 @@ public void fetchBlocks( public void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, blockIds, listener) - .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start(); } }; @@ -128,9 +127,8 @@ public void registerWithShuffleServer( ExecutorShuffleInfo executorInfo) throws IOException { assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); - byte[] registerExecutorMessage = - JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); - client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */); + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); } @Override diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java deleted file mode 100644 index e79420ed8254f..0000000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.Serializable; -import java.util.Arrays; - -import com.google.common.base.Objects; - -/** Messages handled by the {@link ExternalShuffleBlockHandler}. */ -public class ExternalShuffleMessages { - - /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */ - public static class OpenShuffleBlocks implements Serializable { - public final String appId; - public final String execId; - public final String[] blockIds; - - public OpenShuffleBlocks(String appId, String execId, String[] blockIds) { - this.appId = appId; - this.execId = execId; - this.blockIds = blockIds; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("blockIds", Arrays.toString(blockIds)) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof OpenShuffleBlocks) { - OpenShuffleBlocks o = (OpenShuffleBlocks) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Arrays.equals(blockIds, o.blockIds); - } - return false; - } - } - - /** Initial registration message between an executor and its local shuffle server. */ - public static class RegisterExecutor implements Serializable { - public final String appId; - public final String execId; - public final ExecutorShuffleInfo executorInfo; - - public RegisterExecutor( - String appId, - String execId, - ExecutorShuffleInfo executorInfo) { - this.appId = appId; - this.execId = execId; - this.executorInfo = executorInfo; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId, executorInfo); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("executorInfo", executorInfo) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof RegisterExecutor) { - RegisterExecutor o = (RegisterExecutor) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Objects.equal(executorInfo, o.executorInfo); - } - return false; - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9e77a1f68c4b0..8ed2e0b39ad23 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -26,6 +26,9 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.JavaUtils; /** @@ -41,17 +44,21 @@ public class OneForOneBlockFetcher { private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); private final TransportClient client; + private final OpenBlocks openMessage; private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; - private ShuffleStreamHandle streamHandle = null; + private StreamHandle streamHandle = null; public OneForOneBlockFetcher( TransportClient client, + String appId, + String execId, String[] blockIds, BlockFetchingListener listener) { this.client = client; + this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); @@ -76,18 +83,18 @@ public void onFailure(int chunkIndex, Throwable e) { /** * Begins the fetching process, calling the listener with every block fetched. * The given message will be serialized with the Java serializer, and the RPC must return a - * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. + * {@link StreamHandle}. We will send all fetch requests immediately, without throttling. */ - public void start(Object openBlocksMessage) { + public void start() { if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { try { - streamHandle = JavaUtils.deserialize(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java new file mode 100644 index 0000000000000..b4b13b8a6ef5d --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or + * by Spark's NettyBlockTransferService. + * + * At a high level: + * - OpenBlock is handled by both services, but only services shuffle files for the external + * shuffle service. It returns a StreamHandle. + * - UploadBlock is only handled by the NettyBlockTransferService. + * - RegisterExecutor is only handled by the external shuffle service. + */ +public abstract class BlockTransferMessage implements Encodable { + protected abstract Type type(); + + /** Preceding every serialized message is its type, which allows us to deserialize it. */ + public static enum Type { + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + } + + // NB: Java does not support static methods in interfaces, so we must put this in a static class. + public static class Decoder { + /** Deserializes the 'type' byte followed by the message itself. */ + public static BlockTransferMessage fromByteArray(byte[] msg) { + ByteBuf buf = Unpooled.wrappedBuffer(msg); + byte type = buf.readByte(); + switch (type) { + case 0: return OpenBlocks.decode(buf); + case 1: return UploadBlock.decode(buf); + case 2: return RegisterExecutor.decode(buf); + case 3: return StreamHandle.decode(buf); + default: throw new IllegalArgumentException("Unknown message type: " + type); + } + } + } + + /** Serializes the 'type' byte followed by the message itself. */ + public byte[] toByteArray() { + ByteBuf buf = Unpooled.buffer(encodedLength()); + buf.writeByte(type().id); + encode(buf); + assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); + return buf.array(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java similarity index 68% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java rename to network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index d45e64656a0e3..cadc8e8369c6a 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -15,21 +15,24 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle; +package org.apache.spark.network.shuffle.protocol; -import java.io.Serializable; import java.util.Arrays; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; /** Contains all configuration necessary for locating the shuffle files of an executor. */ -public class ExecutorShuffleInfo implements Serializable { +public class ExecutorShuffleInfo implements Encodable { /** The base set of local directories that the executor stores its shuffle files in. */ - final String[] localDirs; + public final String[] localDirs; /** Number of subdirectories created within each localDir. */ - final int subDirsPerLocalDir; + public final int subDirsPerLocalDir; /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */ - final String shuffleManager; + public final String shuffleManager; public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { this.localDirs = localDirs; @@ -61,4 +64,25 @@ public boolean equals(Object other) { } return false; } + + @Override + public int encodedLength() { + return Encoders.StringArrays.encodedLength(localDirs) + + 4 // int + + Encoders.Strings.encodedLength(shuffleManager); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.StringArrays.encode(buf, localDirs); + buf.writeInt(subDirsPerLocalDir); + Encoders.Strings.encode(buf, shuffleManager); + } + + public static ExecutorShuffleInfo decode(ByteBuf buf) { + String[] localDirs = Encoders.StringArrays.decode(buf); + int subDirsPerLocalDir = buf.readInt(); + String shuffleManager = Encoders.Strings.decode(buf); + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java new file mode 100644 index 0000000000000..60485bace643c --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to read a set of blocks. Returns {@link StreamHandle}. */ +public class OpenBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String[] blockIds; + + public OpenBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + protected Type type() { return Type.OPEN_BLOCKS; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenBlocks) { + OpenBlocks o = (OpenBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.StringArrays.encodedLength(blockIds); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.StringArrays.encode(buf, blockIds); + } + + public static OpenBlocks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String[] blockIds = Encoders.StringArrays.decode(buf); + return new OpenBlocks(appId, execId, blockIds); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java new file mode 100644 index 0000000000000..38acae3b31d64 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** + * Initial registration message between an executor and its local shuffle server. + * Returns nothing (empty bye array). + */ +public class RegisterExecutor extends BlockTransferMessage { + public final String appId; + public final String execId; + public final ExecutorShuffleInfo executorInfo; + + public RegisterExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + this.appId = appId; + this.execId = execId; + this.executorInfo = executorInfo; + } + + @Override + protected Type type() { return Type.REGISTER_EXECUTOR; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, executorInfo); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("executorInfo", executorInfo) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterExecutor) { + RegisterExecutor o = (RegisterExecutor) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(executorInfo, o.executorInfo); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + executorInfo.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + executorInfo.encode(buf); + } + + public static RegisterExecutor decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf); + return new RegisterExecutor(appId, execId, executorShuffleInfo); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java similarity index 65% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java rename to network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index 9c94691224328..21369c8cfb0d6 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -15,26 +15,29 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle; +package org.apache.spark.network.shuffle.protocol; import java.io.Serializable; -import java.util.Arrays; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" - * message. This is used by {@link OneForOneBlockFetcher}. + * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. */ -public class ShuffleStreamHandle implements Serializable { +public class StreamHandle extends BlockTransferMessage { public final long streamId; public final int numChunks; - public ShuffleStreamHandle(long streamId, int numChunks) { + public StreamHandle(long streamId, int numChunks) { this.streamId = streamId; this.numChunks = numChunks; } + @Override + protected Type type() { return Type.STREAM_HANDLE; } + @Override public int hashCode() { return Objects.hashCode(streamId, numChunks); @@ -50,11 +53,28 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof ShuffleStreamHandle) { - ShuffleStreamHandle o = (ShuffleStreamHandle) other; + if (other != null && other instanceof StreamHandle) { + StreamHandle o = (StreamHandle) other; return Objects.equal(streamId, o.streamId) && Objects.equal(numChunks, o.numChunks); } return false; } + + @Override + public int encodedLength() { + return 8 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(streamId); + buf.writeInt(numChunks); + } + + public static StreamHandle decode(ByteBuf buf) { + long streamId = buf.readLong(); + int numChunks = buf.readInt(); + return new StreamHandle(streamId, numChunks); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java new file mode 100644 index 0000000000000..38abe29cc585f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ +public class UploadBlock extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String blockId; + // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in + // this package. We should avoid this hack. + public final byte[] metadata; + public final byte[] blockData; + + /** + * @param metadata Meta-information about block, typically StorageLevel. + * @param blockData The actual block's bytes. + */ + public UploadBlock( + String appId, + String execId, + String blockId, + byte[] metadata, + byte[] blockData) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.metadata = metadata; + this.blockData = blockData; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(appId, execId, blockId); + return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .add("block size", blockData.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlock) { + UploadBlock o = (UploadBlock) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata) + && Arrays.equals(blockData, o.blockData); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata) + + Encoders.ByteArrays.encodedLength(blockData); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + Encoders.ByteArrays.encode(buf, blockData); + } + + public static UploadBlock decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + byte[] blockData = Encoders.ByteArrays.decode(buf); + return new UploadBlock(appId, execId, blockId, metadata, blockData); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java similarity index 55% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java rename to network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index ee9482b49cfc3..d65de9ca550a3 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -21,31 +21,24 @@ import static org.junit.Assert.*; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.*; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; - -public class ShuffleMessagesSuite { +/** Verifies that all BlockTransferMessages can be serialized correctly. */ +public class BlockTransferMessagesSuite { @Test public void serializeOpenShuffleBlocks() { - OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2", - new String[] { "block0", "block1" }); - OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); - assertEquals(msg, msg2); + checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); + checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); + checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, + new byte[] { 4, 5, 6, 7} )); + checkSerializeDeserialize(new StreamHandle(12345, 16)); } - @Test - public void serializeRegisterExecutor() { - RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( - new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")); - RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); - assertEquals(msg, msg2); - } - - @Test - public void serializeShuffleStreamHandle() { - ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16); - ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + private void checkSerializeDeserialize(BlockTransferMessage msg) { + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); assertEquals(msg, msg2); + assertEquals(msg.hashCode(), msg2.hashCode()); + assertEquals(msg.toString(), msg2.toString()); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7939cb4d32690..3f9fe1681cf27 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -24,8 +24,6 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import static org.junit.Assert.*; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; @@ -36,7 +34,12 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.protocol.UploadBlock; public class ExternalShuffleBlockHandlerSuite { TransportClient client = mock(TransportClient.class); @@ -57,8 +60,7 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = JavaUtils.serialize( - new RegisterExecutor("app0", "exec1", config)); + byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); handler.receive(client, registerMessage, callback); verify(blockManager, times(1)).registerExecutor("app0", "exec1", config); @@ -75,9 +77,8 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocksMessage = JavaUtils.serialize( - new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" })); - handler.receive(client, openBlocksMessage, callback); + byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + handler.receive(client, openBlocks, callback); verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1"); @@ -85,7 +86,8 @@ public void testOpenShuffleBlocks() { verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); - ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue()); + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); assertEquals(2, handle.numChunks); ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class); @@ -100,18 +102,17 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 }; + byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; try { - handler.receive(client, unserializableMessage, callback); + handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); } catch (Exception e) { // pass } - byte[] unexpectedMessage = JavaUtils.serialize( - new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort")); + byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); try { - handler.receive(client, unexpectedMessage, callback); + handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); } catch (UnsupportedOperationException e) { // pass diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 3bea5b0f253c6..687bde59fdae4 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -42,6 +42,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 848c88f743d50..8afceab1d585a 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -31,6 +31,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index c18346f6966d6..842741e3d354f 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -40,7 +40,9 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; public class OneForOneBlockFetcherSuite { @Test @@ -119,17 +121,19 @@ public void testEmptyBlockFetch() { private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) { TransportClient client = mock(TransportClient.class); BlockFetchingListener listener = mock(BlockFetchingListener.class); - String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); - OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener); + final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + OneForOneBlockFetcher fetcher = + new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( + (byte[]) invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size()))); - assertEquals("OpenZeBlocks", message); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); @@ -161,7 +165,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); - fetcher.start("OpenZeBlocks"); + fetcher.start(); return listener; } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 337b5c7bdb5da..76639114df5d9 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -25,6 +25,8 @@ import com.google.common.io.Files; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; + /** * Manages some sort- and hash-based shuffle data, including the creation * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. From 636d7bcc96b912f5b5caa91110cd55b55fa38ad8 Mon Sep 17 00:00:00 2001 From: wangfei Date: Fri, 7 Nov 2014 11:43:35 -0800 Subject: [PATCH 29/68] [SQL][DOC][Minor] Spark SQL Hive now support dynamic partitioning Author: wangfei Closes #3127 from scwf/patch-9 and squashes the following commits: e39a560 [wangfei] now support dynamic partitioning --- docs/sql-programming-guide.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e399fecbbc78c..ffcce2c588879 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1059,7 +1059,6 @@ in Hive deployments. **Major Hive Features** -* Spark SQL does not currently support inserting to tables using dynamic partitioning. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. From 86e9eaa3f0ec23cb38bce67585adb2d5f484f4ee Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 7 Nov 2014 11:45:25 -0800 Subject: [PATCH 30/68] [SPARK-4225][SQL] Resorts to SparkContext.version to inspect Spark version This PR resorts to `SparkContext.version` rather than META-INF/MANIFEST.MF in the assembly jar to inspect Spark version. Currently, when built with Maven, the MANIFEST.MF file in the assembly jar is incorrectly replaced by Guava 15.0 MANIFEST.MF, probably because of the assembly/shading tricks. Another related PR is #3103, which tries to fix the MANIFEST issue. Author: Cheng Lian Closes #3105 from liancheng/spark-4225 and squashes the following commits: d9585e1 [Cheng Lian] Resorts to SparkContext.version to inspect Spark version --- .../scala/org/apache/spark/util/Utils.scala | 24 ++++++------------- .../thriftserver/SparkSQLCLIService.scala | 12 ++++------ 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a14d6125484fe..6b85c03da533c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,10 +21,8 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.jar.Attributes.Name -import java.util.{Properties, Locale, Random, UUID} -import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} -import java.util.jar.{Manifest => JarManifest} +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.{Locale, Properties, Random, UUID} import scala.collection.JavaConversions._ import scala.collection.Map @@ -38,11 +36,11 @@ import com.google.common.io.{ByteStreams, Files} import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration -import org.apache.log4j.PropertyConfigurator import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} +import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ -import tachyon.client.{TachyonFile,TachyonFS} +import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -352,8 +350,8 @@ private[spark] object Utils extends Logging { * Download a file to target directory. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. * - * If `useCache` is true, first attempts to fetch the file to a local cache that's shared - * across executors running the same application. `useCache` is used mainly for + * If `useCache` is true, first attempts to fetch the file to a local cache that's shared + * across executors running the same application. `useCache` is used mainly for * the executors, and not in local mode. * * Throws SparkException if the target file already exists and has different contents than @@ -400,7 +398,7 @@ private[spark] object Utils extends Logging { } else { doFetchFile(url, targetDir, fileName, conf, securityMgr, hadoopConf) } - + // Decompress the file if it's a .tar or .tar.gz if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) { logInfo("Untarring " + fileName) @@ -1776,13 +1774,6 @@ private[spark] object Utils extends Logging { s"$libraryPathEnvName=$libraryPath$ampersand" } - lazy val sparkVersion = - SparkContext.jarOfObject(this).map { path => - val manifestUrl = new URL(s"jar:file:$path!/META-INF/MANIFEST.MF") - val manifest = new JarManifest(manifestUrl.openStream()) - manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION) - }.getOrElse("Unknown") - /** * Return the value of a config either through the SparkConf or the Hadoop configuration * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf @@ -1796,7 +1787,6 @@ private[spark] object Utils extends Logging { sparkValue } } - } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ecfb74473e921..499e077d7294a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -17,18 +17,16 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.jar.Attributes.Name - -import scala.collection.JavaConversions._ - import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConversions._ + import org.apache.commons.logging.Log -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ @@ -50,7 +48,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null - if (ShimLoader.getHadoopShims().isSecurityEnabled()) { + if (ShimLoader.getHadoopShims.isSecurityEnabled) { try { HiveAuthFactory.loginFromKeytab(hiveConf) sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) @@ -68,7 +66,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) getInfoType match { case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL") case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL") - case GetInfoType.CLI_DBMS_VER => new GetInfoValue(Utils.sparkVersion) + case GetInfoType.CLI_DBMS_VER => new GetInfoValue(hiveContext.sparkContext.version) case _ => super.getInfo(sessionHandle, getInfoType) } } From 8154ed7df6c5407e638f465d3bd86b43f36216ef Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 7 Nov 2014 11:51:20 -0800 Subject: [PATCH 31/68] [SQL] Support ScalaReflection of schema in different universes Author: Michael Armbrust Closes #3096 from marmbrus/reflectionContext and squashes the following commits: adc221f [Michael Armbrust] Support ScalaReflection of schema in different universes --- .../spark/sql/catalyst/ScalaReflection.scala | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9cda373623cb5..71034c2c43c77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -26,14 +26,26 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.types.decimal.Decimal + /** - * Provides experimental support for generating catalyst schemas for scala objects. + * A default version of ScalaReflection that uses the runtime universe. */ -object ScalaReflection { +object ScalaReflection extends ScalaReflection { + val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe +} + +/** + * Support for generating catalyst schemas for scala objects. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + import universe._ + // The Predef.Map is scala.collection.immutable.Map. // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - import scala.reflect.runtime.universe._ case class Schema(dataType: DataType, nullable: Boolean) From 68609c51ad1ab2def302df3c4a1c0bc1ec6e1075 Mon Sep 17 00:00:00 2001 From: Jacky Li Date: Fri, 7 Nov 2014 11:52:08 -0800 Subject: [PATCH 32/68] [SQL] Modify keyword val location according to ordering 'DOUBLE' should be moved before 'ELSE' according to the ordering convension Author: Jacky Li Closes #3080 from jackylk/patch-5 and squashes the following commits: 3c11df7 [Jacky Li] [SQL] Modify keyword val location according to ordering --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 5e613e0f18ba6..affef276c2a88 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -55,10 +55,10 @@ class SqlParser extends AbstractSparkSQLParser { protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val DOUBLE = Keyword("DOUBLE") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") - protected val DOUBLE = Keyword("DOUBLE") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") From 14c54f1876fcf91b5c10e80be2df5421c7328557 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 7 Nov 2014 11:56:40 -0800 Subject: [PATCH 33/68] [SPARK-4213][SQL] ParquetFilters - No support for LT, LTE, GT, GTE operators Following description is quoted from JIRA: When I issue a hql query against a HiveContext where my predicate uses a column of string type with one of LT, LTE, GT, or GTE operator, I get the following error: scala.MatchError: StringType (of class org.apache.spark.sql.catalyst.types.StringType$) Looking at the code in org.apache.spark.sql.parquet.ParquetFilters, StringType is absent from the corresponding functions for creating these filters. To reproduce, in a Hive 0.13.1 shell, I created the following table (at a specified DB): create table sparkbug ( id int, event string ) stored as parquet; Insert some sample data: insert into table sparkbug select 1, '2011-06-18' from limit 1; insert into table sparkbug select 2, '2012-01-01' from limit 1; Launch a spark shell and create a HiveContext to the metastore where the table above is located. import org.apache.spark.sql._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveContext val hc = new HiveContext(sc) hc.setConf("spark.sql.shuffle.partitions", "10") hc.setConf("spark.sql.hive.convertMetastoreParquet", "true") hc.setConf("spark.sql.parquet.compression.codec", "snappy") import hc._ hc.hql("select * from .sparkbug where event >= '2011-12-01'") A scala.MatchError will appear in the output. Author: Kousuke Saruta Closes #3083 from sarutak/SPARK-4213 and squashes the following commits: 4ab6e56 [Kousuke Saruta] WIP b6890c6 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-4213 9a1fae7 [Kousuke Saruta] Fixed ParquetFilters so that compare Strings --- .../spark/sql/parquet/ParquetFilters.scala | 335 +++++++++++++++++- .../spark/sql/parquet/ParquetQuerySuite.scala | 40 +++ 2 files changed, 364 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 517a5cf0029ed..1e67799e8399a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.parquet import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} import org.apache.hadoop.conf.Configuration +import parquet.common.schema.ColumnPath import parquet.filter2.compat.FilterCompat import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.FilterPredicate -import parquet.filter2.predicate.FilterApi +import parquet.filter2.predicate.Operators.{Column, SupportsLtGt} +import parquet.filter2.predicate.{FilterApi, FilterPredicate} import parquet.filter2.predicate.FilterApi._ import parquet.io.api.Binary import parquet.column.ColumnReader @@ -33,9 +35,11 @@ import com.google.common.io.BaseEncoding import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.parquet.ParquetColumns._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" @@ -50,15 +54,25 @@ private[sql] object ParquetFilters { if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null } - def createFilter(expression: Expression): Option[CatalystFilter] ={ + def createFilter(expression: Expression): Option[CatalystFilter] = { def createEqualityFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case BooleanType => - ComparisonFilter.createBooleanFilter( + ComparisonFilter.createBooleanEqualityFilter( name, - literal.value.asInstanceOf[Boolean], + literal.value.asInstanceOf[Boolean], + predicate) + case ByteType => + new ComparisonFilter( + name, + FilterApi.eq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.eq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), predicate) case IntegerType => new ComparisonFilter( @@ -81,18 +95,49 @@ private[sql] object ParquetFilters { FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) case StringType => - ComparisonFilter.createStringFilter( + ComparisonFilter.createStringEqualityFilter( name, literal.value.asInstanceOf[String], predicate) + case BinaryType => + ComparisonFilter.createBinaryEqualityFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.eq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.eq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.eq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createLessThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.lt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.lt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => - new ComparisonFilter( + new ComparisonFilter( name, FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) @@ -111,11 +156,47 @@ private[sql] object ParquetFilters { name, FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringLessThanFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryLessThanFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.lt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.lt(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.lt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createLessThanOrEqualFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.ltEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.ltEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -136,12 +217,48 @@ private[sql] object ParquetFilters { name, FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringLessThanOrEqualFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryLessThanOrEqualFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.ltEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.ltEq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.ltEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } // TODO: combine these two types somehow? def createGreaterThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.gt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.gt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -162,11 +279,47 @@ private[sql] object ParquetFilters { name, FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringGreaterThanFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryGreaterThanFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.gt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.gt(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.gt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createGreaterThanOrEqualFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.gtEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.gtEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -187,6 +340,32 @@ private[sql] object ParquetFilters { name, FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringGreaterThanOrEqualFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryGreaterThanOrEqualFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.gtEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.gtEq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.gtEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } /** @@ -221,9 +400,9 @@ private[sql] object ParquetFilters { case _ => None } } - case p @ EqualTo(left: Literal, right: NamedExpression) => + case p @ EqualTo(left: Literal, right: NamedExpression) if left.dataType != NullType => Some(createEqualityFilter(right.name, left, p)) - case p @ EqualTo(left: NamedExpression, right: Literal) => + case p @ EqualTo(left: NamedExpression, right: Literal) if right.dataType != NullType => Some(createEqualityFilter(left.name, right, p)) case p @ LessThan(left: Literal, right: NamedExpression) => Some(createLessThanFilter(right.name, left, p)) @@ -363,7 +542,7 @@ private[parquet] case class AndFilter( } private[parquet] object ComparisonFilter { - def createBooleanFilter( + def createBooleanEqualityFilter( columnName: String, value: Boolean, predicate: CatalystPredicate): CatalystFilter = @@ -372,7 +551,7 @@ private[parquet] object ComparisonFilter { FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]), predicate) - def createStringFilter( + def createStringEqualityFilter( columnName: String, value: String, predicate: CatalystPredicate): CatalystFilter = @@ -380,4 +559,138 @@ private[parquet] object ComparisonFilter { columnName, FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)), predicate) + + def createStringLessThanFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.lt(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringLessThanOrEqualFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.ltEq(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringGreaterThanFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gt(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringGreaterThanOrEqualFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gtEq(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createBinaryEqualityFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.eq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryLessThanFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.lt(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryLessThanOrEqualFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.ltEq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryGreaterThanFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gt(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryGreaterThanOrEqualFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gtEq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) +} + +private[spark] object ParquetColumns { + + def byteColumn(columnPath: String): ByteColumn = { + new ByteColumn(ColumnPath.fromDotString(columnPath)) + } + + final class ByteColumn(columnPath: ColumnPath) + extends Column[java.lang.Byte](columnPath, classOf[java.lang.Byte]) with SupportsLtGt + + def shortColumn(columnPath: String): ShortColumn = { + new ShortColumn(ColumnPath.fromDotString(columnPath)) + } + + final class ShortColumn(columnPath: ColumnPath) + extends Column[java.lang.Short](columnPath, classOf[java.lang.Short]) with SupportsLtGt + + + def dateColumn(columnPath: String): DateColumn = { + new DateColumn(ColumnPath.fromDotString(columnPath)) + } + + final class DateColumn(columnPath: ColumnPath) + extends Column[WrappedDate](columnPath, classOf[WrappedDate]) with SupportsLtGt + + def timestampColumn(columnPath: String): TimestampColumn = { + new TimestampColumn(ColumnPath.fromDotString(columnPath)) + } + + final class TimestampColumn(columnPath: ColumnPath) + extends Column[WrappedTimestamp](columnPath, classOf[WrappedTimestamp]) with SupportsLtGt + + def decimalColumn(columnPath: String): DecimalColumn = { + new DecimalColumn(ColumnPath.fromDotString(columnPath)) + } + + final class DecimalColumn(columnPath: ColumnPath) + extends Column[Decimal](columnPath, classOf[Decimal]) with SupportsLtGt + + final class WrappedDate(val date: Date) extends Comparable[WrappedDate] { + + override def compareTo(other: WrappedDate): Int = { + date.compareTo(other.date) + } + } + + final class WrappedTimestamp(val timestamp: Timestamp) extends Comparable[WrappedTimestamp] { + + override def compareTo(other: WrappedTimestamp): Int = { + timestamp.compareTo(other.timestamp) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 08d9da27f1b11..3cccafe92d4f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -619,6 +619,46 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA fail(s"optional Int value in result row $i should be ${6*i}") } } + + val query12 = sql("SELECT * FROM testfiltersource WHERE mystring >= \"50\"") + assert( + query12.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result12 = query12.collect() + assert(result12.size === 54) + assert(result12(0).getString(2) == "6") + assert(result12(4).getString(2) == "50") + assert(result12(53).getString(2) == "99") + + val query13 = sql("SELECT * FROM testfiltersource WHERE mystring > \"50\"") + assert( + query13.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result13 = query13.collect() + assert(result13.size === 53) + assert(result13(0).getString(2) == "6") + assert(result13(4).getString(2) == "51") + assert(result13(52).getString(2) == "99") + + val query14 = sql("SELECT * FROM testfiltersource WHERE mystring <= \"50\"") + assert( + query14.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result14 = query14.collect() + assert(result14.size === 148) + assert(result14(0).getString(2) == "0") + assert(result14(46).getString(2) == "50") + assert(result14(147).getString(2) == "200") + + val query15 = sql("SELECT * FROM testfiltersource WHERE mystring < \"50\"") + assert( + query15.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result15 = query15.collect() + assert(result15.size === 147) + assert(result15(0).getString(2) == "0") + assert(result15(46).getString(2) == "100") + assert(result15(146).getString(2) == "200") } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { From 60ab80f501b8384ddf48a9ac0ba0c2b9eb548b28 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 7 Nov 2014 12:15:53 -0800 Subject: [PATCH 34/68] [SPARK-4272] [SQL] Add more unwrapper functions for primitive type in TableReader Currently, the data "unwrap" only support couple of primitive types, not all, it will not cause exception, but may get some performance in table scanning for the type like binary, date, timestamp, decimal etc. Author: Cheng Hao Closes #3136 from chenghao-intel/table_reader and squashes the following commits: fffb729 [Cheng Hao] fix bug for retrieving the timestamp object e9c97a4 [Cheng Hao] Add more unwrapper functions for primitive type in TableReader --- .../apache/spark/sql/hive/HiveInspectors.scala | 4 ---- .../org/apache/spark/sql/hive/TableReader.scala | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 58815daa82276..bdc7e1dac1922 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -115,10 +115,6 @@ private[hive] trait HiveInspectors { } - /** - * Wraps with Hive types based on object inspector. - * TODO: Consolidate all hive OI/data interface code. - */ /** * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index e49f0957d188a..f60bc3788e3e4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -290,6 +290,21 @@ private[hive] object HadoopTableReader extends HiveInspectors { (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) case oi: DoubleObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi: HiveVarcharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + case oi: HiveDecimalObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) + case oi: TimestampObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) + case oi: DateObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value)) + case oi: BinaryObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value)) case oi => (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) } From a6405c5ddcda112f8efd7d50d8e5f44f78a0fa41 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 7 Nov 2014 12:30:47 -0800 Subject: [PATCH 35/68] [SPARK-4270][SQL] Fix Cast from DateType to DecimalType. `Cast` from `DateType` to `DecimalType` throws `NullPointerException`. Author: Takuya UESHIN Closes #3134 from ueshin/issues/SPARK-4270 and squashes the following commits: 7394e4b [Takuya UESHIN] Fix Cast from DateType to DecimalType. --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 2 +- .../sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 22009666196a1..55319e7a79103 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -281,7 +281,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) case DateType => - buildCast[Date](_, d => changePrecision(null, target)) // date can't cast to decimal in Hive + buildCast[Date](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 6bfa0dbd65ba7..918996f11da2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -412,6 +412,8 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(d, LongType), null) checkEvaluation(Cast(d, FloatType), null) checkEvaluation(Cast(d, DoubleType), null) + checkEvaluation(Cast(d, DecimalType.Unlimited), null) + checkEvaluation(Cast(d, DecimalType(10, 2)), null) checkEvaluation(Cast(d, StringType), "1970-01-01") checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") } From ac70c972a51952f801fd02dd5962c0a0c1aba8f8 Mon Sep 17 00:00:00 2001 From: Matthew Taylor Date: Fri, 7 Nov 2014 12:53:08 -0800 Subject: [PATCH 36/68] [SPARK-4203][SQL] Partition directories in random order when inserting into hive table When doing an insert into hive table with partitions the folders written to the file system are in a random order instead of the order defined in table creation. Seems that the loadPartition method in Hive.java has a Map parameter but expects to be called with a map that has a defined ordering such as LinkedHashMap. Working on a test but having intillij problems Author: Matthew Taylor Closes #3076 from tbfenet/partition_dir_order_problem and squashes the following commits: f1b9a52 [Matthew Taylor] Comment format fix bca709f [Matthew Taylor] review changes 0e50f6b [Matthew Taylor] test fix 99f1a31 [Matthew Taylor] partition ordering fix 369e618 [Matthew Taylor] partition ordering fix --- .../hive/execution/InsertIntoHiveTable.scala | 13 +++++-- .../sql/hive/InsertIntoHiveTableSuite.scala | 34 +++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 74b4e7aaa47a5..81390f626726c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.util + import scala.collection.JavaConversions._ import org.apache.hadoop.hive.common.`type`.HiveVarchar @@ -203,6 +205,13 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. val holdDDLTime = false if (partition.nonEmpty) { + + // loadPartition call orders directories created on the iteration order of the this map + val orderedPartitionSpec = new util.LinkedHashMap[String,String]() + table.hiveQlTable.getPartCols().foreach{ + entry=> + orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) + } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -214,7 +223,7 @@ case class InsertIntoHiveTable( db.loadDynamicPartitions( outputPath, qualifiedTableName, - partitionSpec, + orderedPartitionSpec, overwrite, numDynamicPartitions, holdDDLTime, @@ -224,7 +233,7 @@ case class InsertIntoHiveTable( db.loadPartition( outputPath, qualifiedTableName, - partitionSpec, + orderedPartitionSpec, overwrite, holdDDLTime, inheritTableSpecs, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 18dc937dd2b27..5dbfb923139fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql._ +import java.io.File + +import com.google.common.io.Files +import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive /* Implicits */ @@ -91,4 +93,32 @@ class InsertIntoHiveTableSuite extends QueryTest { sql("DROP TABLE hiveTableWithMapValue") } + + test("SPARK-4203:random partition directory order") { + createTable[TestData]("tmp_table") + val tmpDir = Files.createTempDir() + sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='3') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='4') SELECT 'blarr' FROM tmp_table") + def listFolders(path: File, acc: List[String]): List[List[String]] = { + val dir = path.listFiles() + val folders = dir.filter(_.isDirectory).toList + if (folders.isEmpty) { + List(acc.reverse) + } else { + folders.flatMap(x => listFolders(x, x.getName :: acc)) + } + } + val expected = List( + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=2"::Nil, + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil + ) + assert(listFolders(tmpDir,List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + sql("DROP TABLE table_with_partition") + sql("DROP TABLE tmp_table") + } } From d6e55524437026c0c76addeba8f99249a8316716 Mon Sep 17 00:00:00 2001 From: wangfei Date: Fri, 7 Nov 2014 12:55:11 -0800 Subject: [PATCH 37/68] [SPARK-4292][SQL] Result set iterator bug in JDBC/ODBC select * from src, get the wrong result set as follows: ``` ... | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 309 | val_309 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | | 97 | val_97 | ... ``` Author: wangfei Closes #3149 from scwf/SPARK-4292 and squashes the following commits: 1574a43 [wangfei] using result.collect 8b2d845 [wangfei] adding test f64eddf [wangfei] result set iter bug --- .../thriftserver/HiveThriftServer2Suite.scala | 23 +++++++++++++++++++ .../spark/sql/hive/thriftserver/Shim12.scala | 5 ++-- .../spark/sql/hive/thriftserver/Shim13.scala | 5 ++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 65d910a0c3ffc..bba29b2bdca4d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -267,4 +267,27 @@ class HiveThriftServer2Suite extends FunSuite with Logging { assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") } } + + test("SPARK-4292 regression: result set iterator issue") { + withJdbcStatement() { statement => + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + val queries = Seq( + "DROP TABLE IF EXISTS test_4292", + "CREATE TABLE test_4292(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_4292") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT key FROM test_4292") + + Seq(238, 86, 311, 27, 165).foreach { key => + resultSet.next() + assert(resultSet.getInt(1) == key) + } + + statement.executeQuery("DROP TABLE IF EXISTS test_4292") + } + } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 8077d0ec46fd7..e3ba9914c6cc0 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -202,13 +202,12 @@ private[hive] class SparkExecuteStatementOperation( hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) } iter = { - val resultRdd = result.queryExecution.toRdd val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - resultRdd.toLocalIterator + result.toLocalIterator } else { - resultRdd.collect().iterator + result.collect().iterator } } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index 2c1983de1d0d5..f2ceba828296b 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -87,13 +87,12 @@ private[hive] class SparkExecuteStatementOperation( val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) iter = { - val resultRdd = result.queryExecution.toRdd val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - resultRdd.toLocalIterator + result.toLocalIterator } else { - resultRdd.collect().iterator + result.collect().iterator } } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray From 7c9ec529a3483fab48f728481dd1d3663369e50a Mon Sep 17 00:00:00 2001 From: xiao321 <1042460381@qq.com> Date: Fri, 7 Nov 2014 12:56:49 -0800 Subject: [PATCH 38/68] Update JavaCustomReceiver.java MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 数组下标越界 Author: xiao321 <1042460381@qq.com> Closes #3153 from xiao321/patch-1 and squashes the following commits: 0ed17b5 [xiao321] Update JavaCustomReceiver.java --- .../org/apache/spark/examples/streaming/JavaCustomReceiver.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 981bc4f0613a9..99df259b4e8e6 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -70,7 +70,7 @@ public static void main(String[] args) { // Create a input stream with the custom receiver on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') JavaReceiverInputDStream lines = ssc.receiverStream( - new JavaCustomReceiver(args[1], Integer.parseInt(args[2]))); + new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { From 5923dd986ba26d0fcc8707dd8d16863f1c1005cb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 7 Nov 2014 13:08:25 -0800 Subject: [PATCH 39/68] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #3016 (close requested by 'andrewor14') Closes #2798 (close requested by 'andrewor14') Closes #2864 (close requested by 'andrewor14') Closes #3154 (close requested by 'JoshRosen') Closes #3156 (close requested by 'JoshRosen') Closes #214 (close requested by 'kayousterhout') Closes #2584 (close requested by 'andrewor14') From 7779109796c90d789464ab0be35917f963bbe867 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 7 Nov 2014 20:53:03 -0800 Subject: [PATCH 40/68] [SPARK-4304] [PySpark] Fix sort on empty RDD This PR fix sortBy()/sortByKey() on empty RDD. This should be back ported into 1.1/1.2 Author: Davies Liu Closes #3162 from davies/fix_sort and squashes the following commits: 84f64b7 [Davies Liu] add tests 52995b5 [Davies Liu] fix sortByKey() on empty RDD --- python/pyspark/rdd.py | 2 ++ python/pyspark/tests.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 879655dc53f4a..08d047402625f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -521,6 +521,8 @@ def sortPartition(iterator): # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them rddSize = self.count() + if not rddSize: + return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9f625c5c6ca48..491e445a216bf 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -649,6 +649,9 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + def test_sample(self): rdd = self.sc.parallelize(range(0, 100), 4) wo = rdd.sample(False, 0.1, 2).collect() From 7e9d975676d56ace0e84c2200137e4cd4eba074a Mon Sep 17 00:00:00 2001 From: Michelangelo D'Agostino Date: Fri, 7 Nov 2014 22:53:01 -0800 Subject: [PATCH 41/68] [MLLIB] [PYTHON] SPARK-4221: Expose nonnegative ALS in the python API SPARK-1553 added alternating nonnegative least squares to MLLib, however it's not possible to access it via the python API. This pull request resolves that. Author: Michelangelo D'Agostino Closes #3095 from mdagost/python_nmf and squashes the following commits: a6743ad [Michelangelo D'Agostino] Use setters instead of static methods in PythonMLLibAPI. Remove the new static methods I added. Set seed in tests. Change ratings to ratingsRDD in both train and trainImplicit for consistency. 7cffd39 [Michelangelo D'Agostino] Swapped nonnegative and seed in a few more places. 3fdc851 [Michelangelo D'Agostino] Moved seed to the end of the python parameter list. bdcc154 [Michelangelo D'Agostino] Change seed type to java.lang.Long so that it can handle null. cedf043 [Michelangelo D'Agostino] Added in ability to set the seed from python and made that play nice with the nonnegative changes. Also made the python ALS tests more exact. a72fdc9 [Michelangelo D'Agostino] Expose nonnegative ALS in the python API. --- .../mllib/api/python/PythonMLLibAPI.scala | 39 +++++++++++++++--- python/pyspark/mllib/recommendation.py | 40 ++++++++++++------- 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index d832ae34b55e4..70d7138e3060f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -275,12 +275,25 @@ class PythonMLLibAPI extends Serializable { * the Py4J documentation. */ def trainALSModel( - ratings: JavaRDD[Rating], + ratingsJRDD: JavaRDD[Rating], rank: Int, iterations: Int, lambda: Double, - blocks: Int): MatrixFactorizationModel = { - new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks)) + blocks: Int, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** @@ -295,9 +308,23 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int, - alpha: Double): MatrixFactorizationModel = { - new MatrixFactorizationModelWrapper( - ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)) + alpha: Double, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setImplicitPrefs(true) + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setAlpha(alpha) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index e8b998414d319..e26b152e0cdfd 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -44,31 +44,39 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> r2 = (1, 2, 2.0) >>> r3 = (2, 1, 2.0) >>> ratings = sc.parallelize([r1, r2, r3]) - >>> model = ALS.trainImplicit(ratings, 1) - >>> model.predict(2,2) is not None - True + >>> model = ALS.trainImplicit(ratings, 1, seed=10) + >>> model.predict(2,2) + 0.4473... >>> testset = sc.parallelize([(1, 2), (1, 1)]) - >>> model = ALS.train(ratings, 1) - >>> model.predictAll(testset).count() == 2 - True + >>> model = ALS.train(ratings, 1, seed=10) + >>> model.predictAll(testset).collect() + [Rating(1, 1, 1), Rating(1, 2, 1)] - >>> model = ALS.train(ratings, 4) - >>> model.userFeatures().count() == 2 - True + >>> model = ALS.train(ratings, 4, seed=10) + >>> model.userFeatures().collect() + [(2, array('d', [...])), (1, array('d', [...]))] >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] >>> len(latents) == 4 True - >>> model.productFeatures().count() == 2 - True + >>> model.productFeatures().collect() + [(2, array('d', [...])), (1, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] >>> len(latents) == 4 True + + >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) + >>> model.predict(2,2) + 3.735... + + >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) + >>> model.predict(2,2) + 0.4473... """ def predict(self, user, product): return self._java_model.predict(user, product) @@ -101,15 +109,17 @@ def _prepare(cls, ratings): return _to_java_object_rdd(ratings, True) @classmethod - def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): + def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, + seed=None): model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, - lambda_, blocks) + lambda_, blocks, nonnegative, seed) return MatrixFactorizationModel(model) @classmethod - def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): + def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, + nonnegative=False, seed=None): model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, - iterations, lambda_, blocks, alpha) + iterations, lambda_, blocks, alpha, nonnegative, seed) return MatrixFactorizationModel(model) From 7afc8564f33eb2868f458f85046f59a51b516ed6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 7 Nov 2014 23:16:13 -0800 Subject: [PATCH 42/68] [SPARK-4291][Build] Rename network module projects The names of the recently introduced network modules are inconsistent with those of the other modules in the project. We should just drop the "Code" suffix since it doesn't sacrifice any meaning, especially before they get into an official release. ``` [INFO] Reactor Build Order: [INFO] [INFO] Spark Project Parent POM [INFO] Spark Project Common Network Code [INFO] Spark Project Shuffle Streaming Service Code [INFO] Spark Project Core [INFO] Spark Project Bagel [INFO] Spark Project GraphX [INFO] Spark Project Streaming [INFO] Spark Project Catalyst [INFO] Spark Project SQL [INFO] Spark Project ML Library [INFO] Spark Project Tools [INFO] Spark Project Hive [INFO] Spark Project REPL [INFO] Spark Project YARN Parent POM [INFO] Spark Project YARN Stable API [INFO] Spark Project Assembly [INFO] Spark Project External Twitter [INFO] Spark Project External Kafka [INFO] Spark Project External Flume Sink [INFO] Spark Project External Flume [INFO] Spark Project External ZeroMQ [INFO] Spark Project External MQTT [INFO] Spark Project Examples [INFO] Spark Project Yarn Shuffle Service Code ``` Author: Andrew Or Closes #3148 from andrewor14/build-drop-code and squashes the following commits: eac839b [Andrew Or] Network -> Networking d01ad47 [Andrew Or] Rename network module project names --- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- network/yarn/pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/network/common/pom.xml b/network/common/pom.xml index 6144548a8f998..8b24ebf1ba1f2 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-network-common_2.10 jar - Spark Project Common Network Code + Spark Project Networking http://spark.apache.org/ network-common diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index fe5681d463499..27c8467687f10 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-network-shuffle_2.10 jar - Spark Project Shuffle Streaming Service Code + Spark Project Shuffle Streaming Service http://spark.apache.org/ network-shuffle diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index e60d8c1f7876c..6e6f6f3e79296 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-network-yarn_2.10 jar - Spark Project Yarn Shuffle Service Code + Spark Project YARN Shuffle Service http://spark.apache.org/ network-yarn From 4af5c7e24455246c61c1f3c22225507e720d721d Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sat, 8 Nov 2014 13:03:51 -0800 Subject: [PATCH 43/68] [Minor] [Core] Don't NPE on closeQuietly(null) Author: Aaron Davidson Closes #3166 from aarondav/closeQuietlyer and squashes the following commits: 78096b5 [Aaron Davidson] Don't NPE on closeQuietly(null) --- .../main/java/org/apache/spark/network/util/JavaUtils.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 009dbcf01323f..bf8a1fc42fc6d 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -44,7 +44,9 @@ public class JavaUtils { /** Closes the given object, ignoring IOExceptions. */ public static void closeQuietly(Closeable closeable) { try { - closeable.close(); + if (closeable != null) { + closeable.close(); + } } catch (IOException e) { logger.error("IOException should not have been thrown.", e); } From 7b41b17f3296eea3282efbdceb6b28baf128287d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 8 Nov 2014 18:10:23 -0800 Subject: [PATCH 44/68] [SPARK-4301] StreamingContext should not allow start() to be called after calling stop() In Spark 1.0.0+, calling `stop()` on a StreamingContext that has not been started is a no-op which has no side-effects. This allows users to call `stop()` on a fresh StreamingContext followed by `start()`. I believe that this almost always indicates an error and is not behavior that we should support. Since we don't allow `start() stop() start()` then I don't think it makes sense to allow `stop() start()`. The current behavior can lead to resource leaks when StreamingContext constructs its own SparkContext: if I call `stop(stopSparkContext=True)`, then I expect StreamingContext's underlying SparkContext to be stopped irrespective of whether the StreamingContext has been started. This is useful when writing unit test fixtures. Prior discussions: - https://github.com/apache/spark/pull/3053#discussion-diff-19710333R490 - https://github.com/apache/spark/pull/3121#issuecomment-61927353 Author: Josh Rosen Closes #3160 from JoshRosen/SPARK-4301 and squashes the following commits: dbcc929 [Josh Rosen] Address more review comments bdbe5da [Josh Rosen] Stop SparkContext after stopping scheduler, not before. 03e9c40 [Josh Rosen] Always stop SparkContext, even if stop(false) has already been called. 832a7f4 [Josh Rosen] Address review comment 5142517 [Josh Rosen] Add tests; improve Scaladoc. 813e471 [Josh Rosen] Revert workaround added in https://github.com/apache/spark/pull/3053/files#diff-e144dbee130ed84f9465853ddce65f8eR49 5558e70 [Josh Rosen] StreamingContext.stop() should stop SparkContext even if StreamingContext has not been started yet. --- .../spark/streaming/StreamingContext.scala | 38 ++++++++++--------- .../streaming/StreamingContextSuite.scala | 25 +++++++++--- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 23d6d1c5e50fa..54b219711efb9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -436,10 +436,10 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. + * + * @throws SparkException if the context has already been started or stopped. */ def start(): Unit = synchronized { - // Throw exception if the context has already been started once - // or if a stopped context is being started again if (state == Started) { throw new SparkException("StreamingContext has already been started") } @@ -472,8 +472,10 @@ class StreamingContext private[streaming] ( /** * Stop the execution of the streams immediately (does not wait for all received data * to be processed). - * @param stopSparkContext Stop the associated SparkContext or not * + * @param stopSparkContext if true, stops the associated SparkContext. The underlying SparkContext + * will be stopped regardless of whether this StreamingContext has been + * started. */ def stop(stopSparkContext: Boolean = true): Unit = synchronized { stop(stopSparkContext, false) @@ -482,25 +484,27 @@ class StreamingContext private[streaming] ( /** * Stop the execution of the streams, with option of ensuring all received data * has been processed. - * @param stopSparkContext Stop the associated SparkContext or not - * @param stopGracefully Stop gracefully by waiting for the processing of all + * + * @param stopSparkContext if true, stops the associated SparkContext. The underlying SparkContext + * will be stopped regardless of whether this StreamingContext has been + * started. + * @param stopGracefully if true, stops gracefully by waiting for the processing of all * received data to be completed */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { - // Warn (but not fail) if context is stopped twice, - // or context is stopped before starting - if (state == Initialized) { - logWarning("StreamingContext has not been started yet") - return + state match { + case Initialized => logWarning("StreamingContext has not been started yet") + case Stopped => logWarning("StreamingContext has already been stopped") + case Started => + scheduler.stop(stopGracefully) + logInfo("StreamingContext stopped successfully") + waiter.notifyStop() } - if (state == Stopped) { - logWarning("StreamingContext has already been stopped") - return - } // no need to throw an exception as its okay to stop twice - scheduler.stop(stopGracefully) - logInfo("StreamingContext stopped successfully") - waiter.notifyStop() + // Even if the streaming context has not been started, we still need to stop the SparkContext. + // Even if we have already stopped, we still need to attempt to stop the SparkContext because + // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). if (stopSparkContext) sc.stop() + // The state should always be Stopped after calling `stop()`, even if we haven't started yet: state = Stopped } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index f47772947d67c..4b49c4d251645 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -46,10 +46,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w after { if (ssc != null) { ssc.stop() - if (ssc.sc != null) { - // Calling ssc.stop() does not always stop the associated SparkContext. - ssc.sc.stop() - } ssc = null } if (sc != null) { @@ -137,11 +133,16 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc.stop() } - test("stop before start and start after stop") { + test("stop before start") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.stop() // stop before start should not throw exception - ssc.start() + } + + test("start after stop") { + // Regression test for SPARK-4301 + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register() ssc.stop() intercept[SparkException] { ssc.start() // start after stop should throw exception @@ -161,6 +162,18 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc.stop() } + test("stop(stopSparkContext=true) after stop(stopSparkContext=false)") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register() + ssc.stop(stopSparkContext = false) + assert(ssc.sc.makeRDD(1 to 100).collect().size === 100) + ssc.stop(stopSparkContext = true) + // Check that the SparkContext is actually stopped: + intercept[Exception] { + ssc.sc.makeRDD(1 to 100).collect() + } + } + test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.cleaner.ttl", "3600") From 8c99a47a4f0369ff3c1ecaeb860fa61ee789e987 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 9 Nov 2014 17:40:48 -0800 Subject: [PATCH 45/68] SPARK-971 [DOCS] Link to Confluence wiki from project website / documentation This is a trivial change to add links to the wiki from `README.md` and the main docs page. It is already linked to from spark.apache.org. Author: Sean Owen Closes #3169 from srowen/SPARK-971 and squashes the following commits: dcb84d0 [Sean Owen] Add link to wiki from README, docs home page --- README.md | 3 ++- docs/index.md | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9916ac7b1ae8e..8d57d50da96c9 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the [project web page](http://spark.apache.org/documentation.html). +guide, on the [project web page](http://spark.apache.org/documentation.html) +and [project wiki](https://cwiki.apache.org/confluence/display/SPARK). This README file only contains basic setup instructions. ## Building Spark diff --git a/docs/index.md b/docs/index.md index edd622ec90f64..171d6ddad62f3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -112,6 +112,7 @@ options for deployment: **External Resources:** * [Spark Homepage](http://spark.apache.org) +* [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), From d1362659ef5d62db2c9ff0d2a24639abcef4e118 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 9 Nov 2014 17:42:08 -0800 Subject: [PATCH 46/68] SPARK-1344 [DOCS] Scala API docs for top methods Use "k" in javadoc of top and takeOrdered to avoid confusion with type K in pair RDDs. I think this resolves the discussion in SPARK-1344. Author: Sean Owen Closes #3168 from srowen/SPARK-1344 and squashes the following commits: 6963fcc [Sean Owen] Use "k" in javadoc of top and takeOrdered to avoid confusion with type K in pair RDDs --- .../org/apache/spark/api/java/JavaRDDLike.scala | 16 ++++++++-------- .../main/scala/org/apache/spark/rdd/RDD.scala | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index efb8978f7ce12..5a8e5bb1f721a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -493,9 +493,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the top K elements from this RDD as defined by + * Returns the top k (largest) elements from this RDD as defined by * the specified Comparator[T]. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements */ @@ -507,9 +507,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the top K elements from this RDD using the + * Returns the top k (largest) elements from this RDD using the * natural ordering for T. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @return an array of top elements */ def top(num: Int): JList[T] = { @@ -518,9 +518,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the first K elements from this RDD as defined by + * Returns the first k (smallest) elements from this RDD as defined by * the specified Comparator[T] and maintains the order. - * @param num the number of top elements to return + * @param num k, the number of elements to return * @param comp the comparator that defines the order * @return an array of top elements */ @@ -552,9 +552,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the first K elements from this RDD using the + * Returns the first k (smallest) elements from this RDD using the * natural ordering for T while maintain the order. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @return an array of top elements */ def takeOrdered(num: Int): JList[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index c169b2d3fe97f..716f2dd17733b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1096,7 +1096,7 @@ abstract class RDD[T: ClassTag]( } /** - * Returns the top K (largest) elements from this RDD as defined by the specified + * Returns the top k (largest) elements from this RDD as defined by the specified * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) @@ -1106,14 +1106,14 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @param ord the implicit ordering for T * @return an array of top elements */ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse) /** - * Returns the first K (smallest) elements from this RDD as defined by the specified + * Returns the first k (smallest) elements from this RDD as defined by the specified * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]]. * For example: * {{{ @@ -1124,7 +1124,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @param num the number of top elements to return + * @param num k, the number of elements to return * @param ord the implicit ordering for T * @return an array of top elements */ From f73b56f5e5d94f83d980475d3f39548986a92dd6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 9 Nov 2014 18:16:20 -0800 Subject: [PATCH 47/68] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #464 (close requested by 'JoshRosen') Closes #283 (close requested by 'pwendell') Closes #449 (close requested by 'pwendell') Closes #907 (close requested by 'pwendell') Closes #2478 (close requested by 'JoshRosen') Closes #2192 (close requested by 'tdas') Closes #918 (close requested by 'pwendell') Closes #1465 (close requested by 'pwendell') Closes #3135 (close requested by 'JoshRosen') Closes #1693 (close requested by 'tdas') Closes #1279 (close requested by 'pwendell') From f8e5732307dcb1482d9bcf1162a1090ef9a7b913 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 9 Nov 2014 22:11:20 -0800 Subject: [PATCH 48/68] SPARK-1209 [CORE] (Take 2) SparkHadoop{MapRed,MapReduce}Util should not use package org.apache.hadoop andrewor14 Another try at SPARK-1209, to address https://github.com/apache/spark/pull/2814#issuecomment-61197619 I successfully tested with `mvn -Dhadoop.version=1.0.4 -DskipTests clean package; mvn -Dhadoop.version=1.0.4 test` I assume that is what failed Jenkins last time. I also tried `-Dhadoop.version1.2.1` and `-Phadoop-2.4 -Pyarn -Phive` for more coverage. So this is why the class was put in `org.apache.hadoop` to begin with, I assume. One option is to leave this as-is for now and move it only when Hadoop 1.0.x support goes away. This is the other option, which adds a call to force the constructor to be public at run-time. It's probably less surprising than putting Spark code in `org.apache.hadoop`, but, does involve reflection. A `SecurityManager` might forbid this, but it would forbid a lot of stuff Spark does. This would also only affect Hadoop 1.0.x it seems. Author: Sean Owen Closes #3048 from srowen/SPARK-1209 and squashes the following commits: 0d48f4b [Sean Owen] For Hadoop 1.0.x, make certain constructors public, which were public in later versions 466e179 [Sean Owen] Disable MIMA warnings resulting from moving the class -- this was also part of the PairRDDFunctions type hierarchy though? eb61820 [Sean Owen] Move SparkHadoopMapRedUtil / SparkHadoopMapReduceUtil from org.apache.hadoop to org.apache.spark --- .../org/apache/spark/SparkHadoopWriter.scala | 1 + .../mapred/SparkHadoopMapRedUtil.scala | 17 +++++++++++++++-- .../mapreduce/SparkHadoopMapReduceUtil.scala | 5 +++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 1 + .../org/apache/spark/rdd/PairRDDFunctions.scala | 3 ++- project/MimaExcludes.scala | 8 ++++++++ .../sql/parquet/ParquetTableOperations.scala | 1 + .../spark/sql/hive/hiveWriterContainers.scala | 1 + 8 files changed, 32 insertions(+), 5 deletions(-) rename core/src/main/scala/org/apache/{hadoop => spark}/mapred/SparkHadoopMapRedUtil.scala (79%) rename core/src/main/scala/org/apache/{hadoop => spark}/mapreduce/SparkHadoopMapReduceUtil.scala (96%) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 376e69cd997d5..40237596570de 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD /** diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala similarity index 79% rename from core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala rename to core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 0c47afae54c8b..21b782edd2a9e 100644 --- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -15,15 +15,24 @@ * limitations under the License. */ -package org.apache.hadoop.mapred +package org.apache.spark.mapred -private[apache] +import java.lang.reflect.Modifier + +import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} + +private[spark] trait SparkHadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = { val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID]) + // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private. + // Make it accessible if it's not in order to access it. + if (!Modifier.isPublic(ctor.getModifiers)) { + ctor.setAccessible(true) + } ctor.newInstance(conf, jobId).asInstanceOf[JobContext] } @@ -31,6 +40,10 @@ trait SparkHadoopMapRedUtil { val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) + // See above + if (!Modifier.isPublic(ctor.getModifiers)) { + ctor.setAccessible(true) + } ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala similarity index 96% rename from core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala rename to core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 1fca5729c6092..3340673f91156 100644 --- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.hadoop.mapreduce +package org.apache.spark.mapreduce import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} -private[apache] +private[spark] trait SparkHadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = { val klass = firstAvailableClass( diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 351e145f96f9a..e55d03d391e03 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils import org.apache.spark.deploy.SparkHadoopUtil diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index da89f634abaea..462f0d6268a86 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -33,13 +33,14 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, -RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} +RecordWriter => NewRecordWriter} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6a0495f8fd540..a94d09be3bec6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -77,6 +77,14 @@ object MimaExcludes { // SPARK-3822 ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-1209 + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.rdd.PairRDDFunctions") ) case v if v.startsWith("1.1") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index d00860a8bb8a6..74c43e053b03c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -43,6 +43,7 @@ import parquet.hadoop.util.ContextUtil import parquet.io.ParquetDecodingException import parquet.schema.MessageType +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.SQLConf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index bf2ce9df67c58..cc8bb3e172c6e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} From 3c2cff4b9464f8d7535564fcd194631a8e5bb0a5 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sun, 9 Nov 2014 22:29:03 -0800 Subject: [PATCH 49/68] SPARK-3179. Add task OutputMetrics. Author: Sandy Ryza This patch had conflicts when merged, resolved by Committer: Kay Ousterhout Closes #2968 from sryza/sandy-spark-3179 and squashes the following commits: dce4784 [Sandy Ryza] More review feedback 8d350d1 [Sandy Ryza] Fix test against Hadoop 2.5+ e7c74d0 [Sandy Ryza] More review feedback 6cff9c4 [Sandy Ryza] Review feedback fb2dde0 [Sandy Ryza] SPARK-3179 --- .../apache/spark/deploy/SparkHadoopUtil.scala | 46 ++++++- .../apache/spark/executor/TaskMetrics.scala | 28 ++++ .../apache/spark/rdd/PairRDDFunctions.scala | 51 ++++++- .../apache/spark/scheduler/JobLogger.scala | 7 +- .../scala/org/apache/spark/ui/ToolTips.scala | 2 + .../apache/spark/ui/exec/ExecutorsTab.scala | 5 + .../apache/spark/ui/jobs/ExecutorTable.scala | 3 + .../spark/ui/jobs/JobProgressListener.scala | 6 + .../org/apache/spark/ui/jobs/StagePage.scala | 29 +++- .../org/apache/spark/ui/jobs/StageTable.scala | 4 + .../org/apache/spark/ui/jobs/UIData.scala | 2 + .../org/apache/spark/util/JsonProtocol.scala | 21 ++- ...te.scala => InputOutputMetricsSuite.scala} | 41 +++++- .../spark/scheduler/SparkListenerSuite.scala | 1 + .../ui/jobs/JobProgressListenerSuite.scala | 7 + .../apache/spark/util/JsonProtocolSuite.scala | 124 ++++++++++++++++-- 16 files changed, 346 insertions(+), 31 deletions(-) rename core/src/test/scala/org/apache/spark/metrics/{InputMetricsSuite.scala => InputOutputMetricsSuite.scala} (67%) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e28eaad8a5180..60ee115e393ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration @@ -133,14 +134,9 @@ class SparkHadoopUtil extends Logging { */ private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) : Option[() => Long] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) try { - val threadStats = stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) - val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") - val getBytesReadMethod = statisticsDataClass.getDeclaredMethod("getBytesRead") + val threadStats = getFileSystemThreadStatistics(path, conf) + val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() Some(() => f() - baselineBytesRead) @@ -151,6 +147,42 @@ class SparkHadoopUtil extends Logging { } } } + + /** + * Returns a function that can be called to find Hadoop FileSystem bytes written. If + * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will + * return the bytes written on r since t. Reflection is required because thread-level FileSystem + * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). + * Returns None if the required method can't be found. + */ + private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) + : Option[() => Long] = { + try { + val threadStats = getFileSystemThreadStatistics(path, conf) + val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") + val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum + val baselineBytesWritten = f() + Some(() => f() - baselineBytesWritten) + } catch { + case e: NoSuchMethodException => { + logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) + None + } + } + } + + private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { + val qualifiedPath = path.getFileSystem(conf).makeQualified(path) + val scheme = qualifiedPath.toUri().getScheme() + val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + } + + private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { + val statisticsDataClass = + Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + statisticsDataClass.getDeclaredMethod(methodName) + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 57bc2b40cec44..51b5328cb4c8f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -82,6 +82,12 @@ class TaskMetrics extends Serializable { */ var inputMetrics: Option[InputMetrics] = None + /** + * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much + * data was written are stored here. + */ + var outputMetrics: Option[OutputMetrics] = None + /** * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. * This includes read metrics aggregated over all the task's shuffle dependencies. @@ -157,6 +163,16 @@ object DataReadMethod extends Enumeration with Serializable { val Memory, Disk, Hadoop, Network = Value } +/** + * :: DeveloperApi :: + * Method by which output data was written. + */ +@DeveloperApi +object DataWriteMethod extends Enumeration with Serializable { + type DataWriteMethod = Value + val Hadoop = Value +} + /** * :: DeveloperApi :: * Metrics about reading input data. @@ -169,6 +185,18 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { var bytesRead: Long = 0L } +/** + * :: DeveloperApi :: + * Metrics about writing output data. + */ +@DeveloperApi +case class OutputMetrics(writeMethod: DataWriteMethod.Value) { + /** + * Total bytes written + */ + var bytesWritten: Long = 0L +} + /** * :: DeveloperApi :: * Metrics pertaining to shuffle data read in a given task. diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 462f0d6268a86..8c2c959e73bb6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -28,7 +28,7 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} @@ -40,6 +40,7 @@ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer @@ -962,30 +963,40 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val hadoopContext = newTaskAttemptContext(config, attemptId) val format = outfmt.newInstance format match { - case c: Configurable => c.setConf(wrappedConf.value) + case c: Configurable => c.setConf(config) case _ => () } val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) + + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { + var recordsWritten = 0L while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + recordsWritten += 1 } } finally { writer.close(hadoopContext) } committer.commitTask(hadoopContext) + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } 1 } : Int @@ -1006,6 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf) { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf + val wrappedConf = new SerializableWritable(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1033,27 +1045,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { + val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { + var recordsWritten = 0L while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + recordsWritten += 1 } } finally { writer.close() } writer.commit() + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } } self.context.runJob(self, writeToFile) writer.commitJob() } + private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) + : (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) + .map(new Path(_)) + .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) + if (bytesWrittenCallback.isDefined) { + context.taskMetrics.outputMetrics = Some(outputMetrics) + } + (outputMetrics, bytesWrittenCallback) + } + + private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], + outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { + if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0 + && bytesWrittenCallback.isDefined) { + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } + } + } + /** * Return an RDD with the keys of each tuple. */ @@ -1070,3 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord) } + +private[spark] object PairRDDFunctions { + val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 4e3d9de540783..3bb54855bae44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -158,6 +158,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " INPUT_BYTES=" + metrics.bytesRead case None => "" } + val outputMetrics = taskMetrics.outputMetrics match { + case Some(metrics) => + " OUTPUT_BYTES=" + metrics.bytesWritten + case None => "" + } val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match { case Some(metrics) => " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + @@ -173,7 +178,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime case None => "" } - stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + + stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics + shuffleReadMetrics + writeMetrics) } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 51dc08f668a43..6f446c5a95a0a 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -29,6 +29,8 @@ private[spark] object ToolTips { val INPUT = "Bytes read from Hadoop or from Spark storage." + val OUTPUT = "Bytes written to Hadoop." + val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index ba97630f025c1..dd1c2b78c4094 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -48,6 +48,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() + val executorToOutputBytes = HashMap[String, Long]() val executorToShuffleRead = HashMap[String, Long]() val executorToShuffleWrite = HashMap[String, Long]() @@ -78,6 +79,10 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp executorToInputBytes(eid) = executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead } + metrics.outputMetrics.foreach { outputMetrics => + executorToOutputBytes(eid) = + executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten + } metrics.shuffleReadMetrics.foreach { shuffleRead => executorToShuffleRead(eid) = executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index f0e43fbf70976..fa0f96bff34ff 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -45,6 +45,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr Failed Tasks Succeeded Tasks Input + Output Shuffle Read Shuffle Write Shuffle Spill (Memory) @@ -77,6 +78,8 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr {v.succeededTasks} {Utils.bytesToString(v.inputBytes)} + + {Utils.bytesToString(v.outputBytes)} {Utils.bytesToString(v.shuffleRead)} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index e3223403c17f4..8bbde51e1801c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -259,6 +259,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.inputBytes += inputBytesDelta execSummary.inputBytes += inputBytesDelta + val outputBytesDelta = + (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L)) + stageData.outputBytes += outputBytesDelta + execSummary.outputBytes += outputBytesDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 250bddbe2f262..16bc3f6c18d09 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -57,6 +57,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables val hasAccumulators = accumulables.size > 0 val hasInput = stageData.inputBytes > 0 + val hasOutput = stageData.outputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 @@ -74,6 +75,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {Utils.bytesToString(stageData.inputBytes)} }} + {if (hasOutput) { +
  • + Output: + {Utils.bytesToString(stageData.outputBytes)} +
  • + }} {if (hasShuffleRead) {
  • Shuffle read: @@ -162,6 +169,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input", "")) else Nil} ++ + {if (hasOutput) Seq(("Output", "")) else Nil} ++ {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) @@ -172,7 +180,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val taskTable = UIUtils.listingTable( unzipped._1, - taskRow(hasAccumulators, hasInput, hasShuffleRead, hasShuffleWrite, hasBytesSpilled), + taskRow(hasAccumulators, hasInput, hasOutput, hasShuffleRead, hasShuffleWrite, + hasBytesSpilled), tasks, headerClasses = unzipped._2) // Excludes tasks which failed and have incomplete metrics @@ -260,6 +269,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } val inputQuantiles = Input +: getFormattedSizeQuantiles(inputSizes) + val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + } + val outputQuantiles = Output +: getFormattedSizeQuantiles(outputSizes) + val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } @@ -296,6 +310,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { , {gettingResultQuantiles}, if (hasInput) {inputQuantiles} else Nil, + if (hasOutput) {outputQuantiles} else Nil, if (hasShuffleRead) {shuffleReadQuantiles} else Nil, if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, @@ -328,6 +343,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def taskRow( hasAccumulators: Boolean, hasInput: Boolean, + hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, hasBytesSpilled: Boolean)(taskData: TaskUIData): Seq[Node] = { @@ -351,6 +367,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") .getOrElse("") + val maybeOutput = metrics.flatMap(_.outputMetrics) + val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") + val outputReadable = maybeOutput + .map(m => s"${Utils.bytesToString(m.bytesWritten)}") + .getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") @@ -417,6 +439,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {inputReadable} }} + {if (hasOutput) { + + {outputReadable} + + }} {if (hasShuffleRead) { {shuffleReadReadable} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 3b4866e05956d..eae542df85d08 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -45,6 +45,7 @@ private[ui] class StageTableBase( Duration Tasks: Succeeded/Total Input + Output Shuffle Read + commons-io + commons-io + + + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + test-jar + test + com.twitter algebird-core_${scala.binary.version} From 974d334cf06a84317234a6c8e2e9ecca8271fa41 Mon Sep 17 00:00:00 2001 From: Varadharajan Mukundan Date: Mon, 10 Nov 2014 14:32:29 -0800 Subject: [PATCH 59/68] [SPARK-4047] - Generate runtime warnings for example implementation of PageRank Based on SPARK-2434, this PR generates runtime warnings for example implementations (Python, Scala) of PageRank. Author: Varadharajan Mukundan Closes #2894 from varadharajan/SPARK-4047 and squashes the following commits: 5f9406b [Varadharajan Mukundan] [SPARK-4047] - Point users to LogisticRegressionWithSGD and LogisticRegressionWithLBFGS instead of LogisticRegressionModel 252f595 [Varadharajan Mukundan] a. Generate runtime warnings for 05a018b [Varadharajan Mukundan] Fix PageRank implementation's package reference 5c2bf54 [Varadharajan Mukundan] [SPARK-4047] - Generate runtime warnings for example implementation of PageRank --- .../org/apache/spark/examples/JavaHdfsLR.java | 15 +++++++++++++++ .../org/apache/spark/examples/JavaPageRank.java | 13 +++++++++++++ examples/src/main/python/pagerank.py | 8 ++++++++ .../org/apache/spark/examples/LocalFileLR.scala | 6 ++++-- .../org/apache/spark/examples/LocalLR.scala | 6 ++++-- .../org/apache/spark/examples/SparkHdfsLR.scala | 6 ++++-- .../org/apache/spark/examples/SparkLR.scala | 6 ++++-- .../apache/spark/examples/SparkPageRank.scala | 15 +++++++++++++++ .../spark/examples/SparkTachyonHdfsLR.scala | 16 ++++++++++++++++ 9 files changed, 83 insertions(+), 8 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index 6c177de359b60..31a79ddd3fff1 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -30,12 +30,25 @@ /** * Logistic regression based classification. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ public final class JavaHdfsLR { private static final int D = 10; // Number of dimensions private static final Random rand = new Random(42); + static void showWarning() { + String warning = "WARN: This is a naive implementation of Logistic Regression " + + "and is given as an example!\n" + + "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " + + "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " + + "for more conventional use."; + System.err.println(warning); + } + static class DataPoint implements Serializable { DataPoint(double[] x, double y) { this.x = x; @@ -109,6 +122,8 @@ public static void main(String[] args) { System.exit(1); } + showWarning(); + SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD lines = sc.textFile(args[0]); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index c22506491fbff..a5db8accdf138 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -45,10 +45,21 @@ * URL neighbor URL * ... * where URL and their neighbors are separated by space(s). + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.graphx.lib.PageRank */ public final class JavaPageRank { private static final Pattern SPACES = Pattern.compile("\\s+"); + static void showWarning() { + String warning = "WARN: This is a naive implementation of PageRank " + + "and is given as an example! \n" + + "Please use the PageRank implementation found in " + + "org.apache.spark.graphx.lib.PageRank for more conventional use."; + System.err.println(warning); + } + private static class Sum implements Function2 { @Override public Double call(Double a, Double b) { @@ -62,6 +73,8 @@ public static void main(String[] args) throws Exception { System.exit(1); } + showWarning(); + SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index b539c4128cdcc..a5f25d78c1146 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +This is an example implementation of PageRank. For more conventional use, +Please refer to PageRank implementation provided by graphx +""" + import re import sys from operator import add @@ -40,6 +45,9 @@ def parseNeighbors(urls): print >> sys.stderr, "Usage: pagerank " exit(-1) + print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is + given as an example! Please refer to PageRank implementation provided by graphx""" + # Initialize the spark context. sc = SparkContext(appName="PythonPageRank") diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 931faac5463c4..ac2ea35bbd0e0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object LocalFileLR { val D = 10 // Numer of dimensions @@ -41,7 +42,8 @@ object LocalFileLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 2d75b9d2590f8..92a683ad57ea1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object LocalLR { val N = 10000 // Number of data points @@ -48,7 +49,8 @@ object LocalLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 3258510894372..9099c2fcc90b3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -32,7 +32,8 @@ import org.apache.spark.scheduler.InputFormatInfo * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkHdfsLR { val D = 10 // Numer of dimensions @@ -54,7 +55,8 @@ object SparkHdfsLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index fc23308fc4adf..257a7d29f922a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -30,7 +30,8 @@ import org.apache.spark._ * Usage: SparkLR [slices] * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkLR { val N = 10000 // Number of data points @@ -53,7 +54,8 @@ object SparkLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 4c7e006da0618..8d092b6506d33 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -28,13 +28,28 @@ import org.apache.spark.{SparkConf, SparkContext} * URL neighbor URL * ... * where URL and their neighbors are separated by space(s). + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.graphx.lib.PageRank */ object SparkPageRank { + + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of PageRank and is given as an example! + |Please use the PageRank implementation found in org.apache.spark.graphx.lib.PageRank + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { if (args.length < 1) { System.err.println("Usage: SparkPageRank ") System.exit(1) } + + showWarning() + val sparkConf = new SparkConf().setAppName("PageRank") val iters = if (args.length > 0) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 96d13612e46dd..4393b99e636b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -32,11 +32,24 @@ import org.apache.spark.storage.StorageLevel /** * Logistic regression based classification. * This example uses Tachyon to persist rdds during computation. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkTachyonHdfsLR { val D = 10 // Numer of dimensions val rand = new Random(42) + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of Logistic Regression and is given as an example! + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS + |for more conventional use. + """.stripMargin) + } + case class DataPoint(x: Vector[Double], y: Double) def parsePoint(line: String): DataPoint = { @@ -51,6 +64,9 @@ object SparkTachyonHdfsLR { } def main(args: Array[String]) { + + showWarning() + val inputPath = args(0) val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") val conf = new Configuration() From 6e7a309b814291d5936c2b5a7b22151b30ea2614 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 10 Nov 2014 14:56:06 -0800 Subject: [PATCH 60/68] Revert "[SPARK-2703][Core]Make Tachyon related unit tests execute without deploying a Tachyon system locally." This reverts commit bd86cb1738800a0aa4c88b9afdba2f97ac6cbf25. --- core/pom.xml | 7 ------- .../org/apache/spark/storage/BlockManagerSuite.scala | 11 ++--------- project/SparkBuild.scala | 2 -- 3 files changed, 2 insertions(+), 18 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 92e9f1fc46275..41296e0eca330 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -204,13 +204,6 @@ derby test - - org.tachyonproject - tachyon - 0.5.0 - test-jar - test - org.tachyonproject tachyon-client diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 86503c9a02058..9529502bc8e10 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -36,7 +36,6 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import tachyon.master.LocalTachyonCluster import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod @@ -537,14 +536,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("tachyon storage") { - val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", true) + // TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar. + val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false) if (tachyonUnitTestEnabled) { - val tachyonCluster = new LocalTachyonCluster(30000000) - tachyonCluster.start() - val tachyonURL = tachyon.Constants.HEADER + - tachyonCluster.getMasterHostname() + ":" + tachyonCluster.getMasterPort() - conf.set("spark.tachyonStore.url", tachyonURL) - conf.set("spark.tachyonStore.folderName", "app-test") store = makeBlockManager(1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -555,7 +549,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(store.getSingle("a3").isDefined, "a3 was in store") assert(store.getSingle("a2").isDefined, "a2 was in store") assert(store.getSingle("a1").isDefined, "a1 was in store") - tachyonCluster.stop() } else { info("tachyon storage test disabled.") } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 351e57a4b578b..657e4b4432775 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -360,8 +360,6 @@ object TestSettings { testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", - // Enable Tachyon local testing. - libraryDependencies += "org.tachyonproject" % "tachyon" % "0.5.0" % "test" classifier "tests", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), From dbf10588de03e8ea993fff687a78727eff55db1f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 10 Nov 2014 15:55:15 -0800 Subject: [PATCH 61/68] [SPARK-4319][SQL] Enable an ignored test "null count". Author: Takuya UESHIN Closes #3185 from ueshin/issues/SPARK-4319 and squashes the following commits: a44a38e [Takuya UESHIN] Enable an ignored test "null count". --- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 ++++----- .../src/test/scala/org/apache/spark/sql/TestData.scala | 9 +++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 702714af5308d..8a80724c08c7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -281,14 +281,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 3) } - // No support for primitive nulls yet. - ignore("null count") { + test("null count") { checkAnswer( - sql("SELECT a, COUNT(b) FROM testData3"), - Seq((1,0), (2, 1))) + sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), + Seq((1, 0), (2, 1))) checkAnswer( - testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), + sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), (2, 1, 2, 2, 1) :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index ef87a230639bc..92b49e8155900 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -64,11 +64,12 @@ object TestData { BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD binaryData.registerTempTable("binaryData") - // TODO: There is no way to express null primitives as case classes currently... + case class TestData3(a: Int, b: Option[Int]) val testData3 = - logical.LocalRelation('a.int, 'b.int).loadData( - (1, null) :: - (2, 2) :: Nil) + TestSQLContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toSchemaRDD + testData3.registerTempTable("testData3") val emptyTableData = logical.LocalRelation('a.int, 'b.int) From 534b23141715b69a89531d93d4b9b78cf2789ff4 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 10 Nov 2014 16:17:52 -0800 Subject: [PATCH 62/68] [SPARK-4000][Build] Uploads HiveCompatibilitySuite logs This is a follow up of #2845. In addition to unit-tests.log files, also upload failure output files generated by `HiveCompatibilitySuite` to Jenkins master. These files can be very helpful to debug Hive compatibility test failures. /cc pwendell marmbrus Author: Cheng Lian Closes #2993 from liancheng/upload-hive-compat-logs and squashes the following commits: 8e6247f [Cheng Lian] Uploads HiveCompatibilitySuite logs --- dev/run-tests-jenkins | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 87c6715153da7..6a849e4f77207 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -101,7 +101,13 @@ function post_message () { function send_archived_logs () { echo "Archiving unit tests logs..." - local log_files=$(find . -name "unit-tests.log") + local log_files=$( + find .\ + -name "unit-tests.log" -o\ + -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\ + -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\ + -path "./sql/hive/target/HiveCompatibilitySuite.wrong" + ) if [ -z "$log_files" ]; then echo "> No log files found." >&2 From acb55aeddbe58758d75b9aed130634afe21797cf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 10 Nov 2014 16:56:36 -0800 Subject: [PATCH 63/68] [SPARK-4308][SQL] Sets SQL operation state to ERROR when exception is thrown In `HiveThriftServer2`, when an exception is thrown during a SQL execution, the SQL operation state should be set to `ERROR`, but now it remains `RUNNING`. This affects the result of the `GetOperationStatus` Thrift API. Author: Cheng Lian Closes #3175 from liancheng/fix-op-state and squashes the following commits: 6d4c1fe [Cheng Lian] Sets SQL operation state to ERROR when exception is thrown --- .../thriftserver/AbstractSparkSQLDriver.scala | 2 -- .../spark/sql/hive/thriftserver/Shim12.scala | 12 +++---- .../spark/sql/hive/thriftserver/Shim13.scala | 36 ++++++++----------- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index fcb302edbffa8..6ed8fd2768f95 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.thriftserver import scala.collection.JavaConversions._ -import java.util.{ArrayList => JArrayList} - import org.apache.commons.lang.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index e3ba9914c6cc0..aa2e3cab72bb9 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -25,9 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, Map => SMap} import scala.math._ import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.shims.ShimLoader import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli._ @@ -37,9 +35,9 @@ import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.SetCommand import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} -import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext} import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} /** * A compatibility layer for interacting with Hive version 0.12.0. @@ -71,8 +69,9 @@ private[hive] class SparkExecuteStatementOperation( statement: String, confOverlay: JMap[String, String])( hiveContext: HiveContext, - sessionToActivePool: SMap[HiveSession, String]) extends ExecuteStatementOperation( - parentSession, statement, confOverlay) with Logging { + sessionToActivePool: SMap[HiveSession, String]) + extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { + private var result: SchemaRDD = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -216,6 +215,7 @@ private[hive] class SparkExecuteStatementOperation( // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => + setState(OperationState.ERROR) logError("Error executing query:",e) throw new HiveSQLException(e.toString) } diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index f2ceba828296b..a642478d08857 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -27,10 +27,9 @@ import scala.collection.mutable.{ArrayBuffer, Map => SMap} import scala.math._ import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.shims.ShimLoader import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli._ @@ -39,9 +38,9 @@ import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.{Row => SparkRow, SchemaRDD} -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} /** * A compatibility layer for interacting with Hive version 0.12.0. @@ -100,6 +99,7 @@ private[hive] class SparkExecuteStatementOperation( // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => + setState(OperationState.ERROR) logError("Error executing query:",e) throw new HiveSQLException(e.toString) } @@ -194,14 +194,12 @@ private[hive] class SparkExecuteStatementOperation( try { sqlOperationConf.verifyAndSet(confEntry.getKey, confEntry.getValue) } - catch { - case e: IllegalArgumentException => { - throw new HiveSQLException("Error applying statement specific settings", e) - } + catch { case e: IllegalArgumentException => + throw new HiveSQLException("Error applying statement specific settings", e) } } } - return sqlOperationConf + sqlOperationConf } def run(): Unit = { @@ -219,7 +217,7 @@ private[hive] class SparkExecuteStatementOperation( val currentUGI: UserGroupInformation = ShimLoader.getHadoopShims.getUGIForConf(opConfig) val backgroundOperation: Runnable = new Runnable { - def run { + def run() { val doAsAction: PrivilegedExceptionAction[AnyRef] = new PrivilegedExceptionAction[AnyRef] { def run: AnyRef = { @@ -228,23 +226,19 @@ private[hive] class SparkExecuteStatementOperation( try { runInternal(statement) } - catch { - case e: HiveSQLException => { - setOperationException(e) - logError("Error running hive query: ", e) - } + catch { case e: HiveSQLException => + setOperationException(e) + logError("Error running hive query: ", e) } - return null + null } } try { ShimLoader.getHadoopShims.doAs(currentUGI, doAsAction) } - catch { - case e: Exception => { - setOperationException(new HiveSQLException(e)) - logError("Error running hive query as user : " + currentUGI.getShortUserName, e) - } + catch { case e: Exception => + setOperationException(new HiveSQLException(e)) + logError("Error running hive query as user : " + currentUGI.getShortUserName, e) } setState(OperationState.FINISHED) } From d793d80c8084923ea04dcf7d268eec8ede490127 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 10 Nov 2014 17:20:52 -0800 Subject: [PATCH 64/68] [SQL] remove a decimal case branch that has no effect at runtime it generates warnings at compile time marmbrus Author: Xiangrui Meng Closes #3192 from mengxr/dtc-decimal and squashes the following commits: 955e9fb [Xiangrui Meng] remove a decimal case branch that has no effect --- .../org/apache/spark/sql/types/util/DataTypeConversions.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 3fa4a7c6481d3..9aad7b3df4eed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -133,7 +133,6 @@ protected[sql] object DataTypeConversions { def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d)) - case (d: java.math.BigDecimal, _) => BigDecimal(d) case (other, _) => other } From fa777833b52b6f339cdc335e8e3935cfe9a2a7eb Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 10 Nov 2014 17:22:57 -0800 Subject: [PATCH 65/68] [SPARK-4250] [SQL] Fix bug of constant null value mapping to ConstantObjectInspector Author: Cheng Hao Closes #3114 from chenghao-intel/constant_null_oi and squashes the following commits: e603bda [Cheng Hao] fix the bug of null value for primitive types 50a13ba [Cheng Hao] fix the timezone issue f54f369 [Cheng Hao] fix bug of constant null value for ObjectInspector --- .../spark/sql/hive/HiveInspectors.scala | 78 ++++++++++-------- ...testing-0-9a02bc7de09bcabcbd4c91f54a814c20 | 1 + .../udf_if-0-b7ffa85b5785cccef2af1b285348cc2c | 1 + .../udf_if-1-30cf7f51f92b5684e556deff3032d49a | 1 + .../udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 | 0 .../udf_if-3-20206f17367ff284d67044abd745ce9f | 1 + .../udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca | 0 .../udf_if-5-a7db13aec05c97792f9331d63709d8cc | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 52 +++++++++++- .../org/apache/spark/sql/hive/Shim12.scala | 70 ++++++++++------ .../org/apache/spark/sql/hive/Shim13.scala | 80 +++++++++++++------ 11 files changed, 199 insertions(+), 86 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 create mode 100644 sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c create mode 100644 sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a create mode 100644 sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 create mode 100644 sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f create mode 100644 sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca create mode 100644 sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index bdc7e1dac1922..7e76aff642bb5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -88,6 +88,7 @@ private[hive] trait HiveInspectors { * @return convert the data into catalyst type */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { + case _ if data == null => null case hvoi: HiveVarcharObjectInspector => if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue case hdoi: HiveDecimalObjectInspector => @@ -250,46 +251,53 @@ private[hive] trait HiveInspectors { } def toInspector(expr: Expression): ObjectInspector = expr match { - case Literal(value: String, StringType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Int, IntegerType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Double, DoubleType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Boolean, BooleanType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Long, LongType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Float, FloatType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Short, ShortType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Byte, ByteType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Array[Byte], BinaryType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: java.sql.Date, DateType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: java.sql.Timestamp, TimestampType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: BigDecimal, DecimalType()) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Decimal, DecimalType()) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal) + case Literal(value, StringType) => + HiveShim.getStringWritableConstantObjectInspector(value) + case Literal(value, IntegerType) => + HiveShim.getIntWritableConstantObjectInspector(value) + case Literal(value, DoubleType) => + HiveShim.getDoubleWritableConstantObjectInspector(value) + case Literal(value, BooleanType) => + HiveShim.getBooleanWritableConstantObjectInspector(value) + case Literal(value, LongType) => + HiveShim.getLongWritableConstantObjectInspector(value) + case Literal(value, FloatType) => + HiveShim.getFloatWritableConstantObjectInspector(value) + case Literal(value, ShortType) => + HiveShim.getShortWritableConstantObjectInspector(value) + case Literal(value, ByteType) => + HiveShim.getByteWritableConstantObjectInspector(value) + case Literal(value, BinaryType) => + HiveShim.getBinaryWritableConstantObjectInspector(value) + case Literal(value, DateType) => + HiveShim.getDateWritableConstantObjectInspector(value) + case Literal(value, TimestampType) => + HiveShim.getTimestampWritableConstantObjectInspector(value) + case Literal(value, DecimalType()) => + HiveShim.getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => HiveShim.getPrimitiveNullWritableConstantObjectInspector - case Literal(value: Seq[_], ArrayType(dt, _)) => + case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) - val list = new java.util.ArrayList[Object]() - value.foreach(v => list.add(wrap(v, listObjectInspector))) - ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) - case Literal(map: Map[_, _], MapType(keyType, valueType, _)) => - val value = new java.util.HashMap[Object, Object]() + if (value == null) { + ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) + } else { + val list = new java.util.ArrayList[Object]() + value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector))) + ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) + } + case Literal(value, MapType(keyType, valueType, _)) => val keyOI = toInspector(keyType) val valueOI = toInspector(valueType) - map.foreach (entry => value.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI))) - ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, value) - case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") + if (value == null) { + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, null) + } else { + val map = new java.util.HashMap[Object, Object]() + value.asInstanceOf[Map[_, _]].foreach (entry => { + map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)) + }) + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) + } case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 new file mode 100644 index 0000000000000..7c41615f8c184 --- /dev/null +++ b/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 @@ -0,0 +1 @@ +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c new file mode 100644 index 0000000000000..2cf0d9d61882e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c @@ -0,0 +1 @@ +There is no documentation for function 'if' diff --git a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a new file mode 100644 index 0000000000000..2cf0d9d61882e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a @@ -0,0 +1 @@ +There is no documentation for function 'if' diff --git a/sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 b/sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f b/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f new file mode 100644 index 0000000000000..a29e96cbd1db7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f @@ -0,0 +1 @@ +1 1 1 1 NULL 2 diff --git a/sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca b/sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc b/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc new file mode 100644 index 0000000000000..f0669b86989d0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc @@ -0,0 +1 @@ +128 1.1 ABC 12.3 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b897dff0159ff..684d22807c0c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.util.{Locale, TimeZone} + +import org.scalatest.BeforeAndAfter import scala.util.Try @@ -28,14 +31,59 @@ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.{SQLConf, Row, SchemaRDD} case class TestData(a: Int, b: String) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. */ -class HiveQuerySuite extends HiveComparisonTest { +class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + + override def beforeAll() { + TestHive.cacheTables = true + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + } + + override def afterAll() { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + } + + createQueryTest("constant null testing", + """SELECT + |IF(FALSE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL1, + |IF(TRUE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL2, + |IF(FALSE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL3, + |IF(TRUE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL4, + |IF(FALSE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL5, + |IF(TRUE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL6, + |IF(FALSE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL7, + |IF(TRUE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL8, + |IF(FALSE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL9, + |IF(TRUE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL10, + |IF(FALSE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL11, + |IF(TRUE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL12, + |IF(FALSE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL13, + |IF(TRUE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL14, + |IF(FALSE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL15, + |IF(TRUE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL16, + |IF(FALSE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL17, + |IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18, + |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, + |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, + |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |FROM src LIMIT 1""".stripMargin) + createQueryTest("constant array", """ |SELECT sort_array( diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 8e946b7e82f5d..8ba25f889d176 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -57,54 +57,74 @@ private[hive] object HiveShim { new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) } - def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector = + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, new hadoopIo.Text(value)) + PrimitiveCategory.STRING, + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) - def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector = + def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.INT, new hadoopIo.IntWritable(value)) + PrimitiveCategory.INT, + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) - def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector = + def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DOUBLE, new hiveIo.DoubleWritable(value)) + PrimitiveCategory.DOUBLE, + if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double])) - def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector = + def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BOOLEAN, new hadoopIo.BooleanWritable(value)) + PrimitiveCategory.BOOLEAN, + if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])) - def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector = + def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.LONG, new hadoopIo.LongWritable(value)) + PrimitiveCategory.LONG, + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) - def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector = + def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.FLOAT, new hadoopIo.FloatWritable(value)) + PrimitiveCategory.FLOAT, + if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float])) - def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector = + def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.SHORT, new hiveIo.ShortWritable(value)) + PrimitiveCategory.SHORT, + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) - def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector = + def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BYTE, new hiveIo.ByteWritable(value)) + PrimitiveCategory.BYTE, + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) - def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector = + def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BINARY, new hadoopIo.BytesWritable(value)) + PrimitiveCategory.BINARY, + if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector = + def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DATE, new hiveIo.DateWritable(value)) + PrimitiveCategory.DATE, + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector = + def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.TIMESTAMP, new hiveIo.TimestampWritable(value)) - - def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector = + PrimitiveCategory.TIMESTAMP, + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + }) + + def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DECIMAL, - new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))) + if (value == null) { + null + } else { + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + }) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 0bc330cdbecb1..e4aee57f0ad9f 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -56,54 +56,86 @@ private[hive] object HiveShim { new TableDesc(inputFormatClass, outputFormatClass, properties) } - def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector = + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, new hadoopIo.Text(value)) + TypeInfoFactory.stringTypeInfo, + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) - def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector = + def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, new hadoopIo.IntWritable(value)) + TypeInfoFactory.intTypeInfo, + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) - def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector = + def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, new hiveIo.DoubleWritable(value)) + TypeInfoFactory.doubleTypeInfo, if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector = + def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, new hadoopIo.BooleanWritable(value)) + TypeInfoFactory.booleanTypeInfo, if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector = + def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, new hadoopIo.LongWritable(value)) + TypeInfoFactory.longTypeInfo, + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) - def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector = + def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, new hadoopIo.FloatWritable(value)) + TypeInfoFactory.floatTypeInfo, if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector = + def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, new hiveIo.ShortWritable(value)) + TypeInfoFactory.shortTypeInfo, + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) - def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector = + def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, new hiveIo.ByteWritable(value)) + TypeInfoFactory.byteTypeInfo, + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) - def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector = + def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, new hadoopIo.BytesWritable(value)) + TypeInfoFactory.binaryTypeInfo, if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + }) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector = + def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, new hiveIo.DateWritable(value)) + TypeInfoFactory.dateTypeInfo, + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector = + def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, new hiveIo.TimestampWritable(value)) + TypeInfoFactory.timestampTypeInfo, if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + }) - def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector = + def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( TypeInfoFactory.decimalTypeInfo, - new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))) + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + }) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( From a1fc059b69c9ed150bf8a284404cc149ddaa27d6 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 10 Nov 2014 17:26:03 -0800 Subject: [PATCH 66/68] [SPARK-4149][SQL] ISO 8601 support for json date time strings This implement the feature davies mentioned in https://github.com/apache/spark/pull/2901#discussion-diff-19313312 Author: Daoyuan Wang Closes #3012 from adrian-wang/iso8601 and squashes the following commits: 50df6e7 [Daoyuan Wang] json data timestamp ISO8601 support --- .../org/apache/spark/sql/json/JsonRDD.scala | 5 ++-- .../sql/types/util/DataTypeConversions.scala | 30 +++++++++++++++++++ .../org/apache/spark/sql/json/JsonSuite.scala | 7 +++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0f2dcdcacf0ca..d9d7a3fea3963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.json import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types.util.DataTypeConversions import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} @@ -378,7 +379,7 @@ private[sql] object JsonRDD extends Logging { private def toDate(value: Any): Date = { value match { // only support string as date - case value: java.lang.String => Date.valueOf(value) + case value: java.lang.String => new Date(DataTypeConversions.stringToTime(value).getTime) } } @@ -386,7 +387,7 @@ private[sql] object JsonRDD extends Logging { value match { case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => Timestamp.valueOf(value) + case value: java.lang.String => toTimestamp(DataTypeConversions.stringToTime(value).getTime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 9aad7b3df4eed..d4258156f18f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types.util +import java.text.SimpleDateFormat + import scala.collection.JavaConverters._ import org.apache.spark.sql._ @@ -129,6 +131,34 @@ protected[sql] object DataTypeConversions { StructType(structType.getFields.map(asScalaStructField)) } + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + java.sql.Timestamp.valueOf(s) + } else { + java.sql.Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } + /** Converts Java objects to catalyst rows / types */ def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index cade244f7ac39..f8ca2c773d9ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -66,6 +66,13 @@ class JsonSuite extends QueryTest { val strDate = "2014-10-15" checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType)) + + val ISO8601Time1 = "1970-01-01T01:00:01.0Z" + checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) + checkTypePromotion(new Date(3601000), enforceCorrectType(ISO8601Time1, DateType)) + val ISO8601Time2 = "1970-01-01T02:00:01-01:00" + checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) + checkTypePromotion(new Date(10801000), enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { From ce6ed2abd14de26b9ceaa415e9a42fbb1338f5fa Mon Sep 17 00:00:00 2001 From: surq Date: Mon, 10 Nov 2014 17:37:16 -0800 Subject: [PATCH 67/68] [SPARK-3954][Streaming] Optimization to FileInputDStream about convert files to RDDS there are 3 loops with files sequence in spark source. loops files sequence: 1.files.map(...) 2.files.zip(fileRDDs) 3.files-size.foreach It's will very time consuming when lots of files.So I do the following correction: 3 loops with files sequence => only one loop Author: surq Closes #2811 from surq/SPARK-3954 and squashes the following commits: 321bbe8 [surq] updated the code style.The style from [for...yield]to [files.map(file=>{})] 88a2c20 [surq] Merge branch 'master' of https://github.com/apache/spark into SPARK-3954 178066f [surq] modify code's style. [Exceeds 100 columns] 626ef97 [surq] remove redundant import(ArrayBuffer) 739341f [surq] promote the speed of convert files to RDDS --- .../apache/spark/streaming/dstream/FileInputDStream.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 8152b7542ac57..55d6cf6a783ea 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -120,14 +120,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - val fileRDDs = files.map(file => context.sparkContext.newAPIHadoopFile[K, V, F](file)) - files.zip(fileRDDs).foreach { case (file, rdd) => { + val fileRDDs = files.map(file =>{ + val rdd = context.sparkContext.newAPIHadoopFile[K, V, F](file) if (rdd.partitions.size == 0) { logError("File " + file + " has no data in it. Spark Streaming can only ingest " + "files that have been \"moved\" to the directory assigned to the file stream. " + "Refer to the streaming programming guide for more details.") } - }} + rdd + }) new UnionRDD(context.sparkContext, fileRDDs) } From c764d0ac1c6410ca2dd2558cb6bcbe8ad5f02481 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 10 Nov 2014 17:46:05 -0800 Subject: [PATCH 68/68] [SPARK-4274] [SQL] Fix NPE in printing the details of the query plan Author: Cheng Hao Closes #3139 from chenghao-intel/comparison_test and squashes the following commits: f5d7146 [Cheng Hao] avoid exception in printing the codegen enabled --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 84eaf401f240c..31cc4170aa867 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -444,7 +444,7 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} - |Code Generation: ${executedPlan.codegenEnabled} + |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} |== RDD == """.stripMargin.trim }