From 8191bcb7d5b1d38985096d4ace402fab10148bdf Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 3 Nov 2014 16:53:38 -0800 Subject: [PATCH 01/10] [SPARK-2938] Support SASL authentication in NettyBlockTransferService Also lays the groundwork for supporting it inside the external shuffle service. --- .../org/apache/spark/SecurityManager.scala | 15 +- .../scala/org/apache/spark/SparkConf.scala | 6 + .../scala/org/apache/spark/SparkContext.scala | 2 + .../scala/org/apache/spark/SparkEnv.scala | 3 +- .../org/apache/spark/SparkSaslClient.scala | 147 --------------- .../org/apache/spark/SparkSaslServer.scala | 176 ------------------ .../org/apache/spark/executor/Executor.scala | 1 + .../netty/NettyBlockTransferService.scala | 25 ++- .../apache/spark/network/nio/Connection.scala | 5 +- .../spark/network/nio/ConnectionManager.scala | 7 +- .../apache/spark/storage/BlockManager.scala | 46 +++-- .../BlockManagerReplicationSuite.scala | 2 + .../spark/storage/BlockManagerSuite.scala | 4 +- .../spark/network/TransportContext.java | 15 +- .../spark/network/client/TransportClient.java | 5 + .../client/TransportClientBootstrap.java | 22 +++ .../client/TransportClientFactory.java | 45 +++-- .../spark/network/server/RpcHandler.java | 2 + .../spark/network/sasl/SaslBootstrap.java | 54 ++++++ .../spark/network/sasl/SaslRpcHandler.java | 131 +++++++++++++ .../spark/network/sasl/SecretKeyHolder.java | 35 ++++ .../spark/network/sasl/SparkSaslClient.java | 138 ++++++++++++++ .../spark/network/sasl/SparkSaslServer.java | 172 +++++++++++++++++ .../shuffle/ExternalShuffleClient.java | 15 +- .../spark/network/shuffle/ShuffleClient.java | 11 +- .../ExternalShuffleIntegrationSuite.java | 6 +- .../streaming/ReceivedBlockHandlerSuite.scala | 1 + 27 files changed, 717 insertions(+), 374 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/SparkSaslClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/SparkSaslServer.scala create mode 100644 network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0e0f1a7b2377e..a77e8e788d730 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication} import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.sasl.SecretKeyHolder /** * Spark class responsible for security. @@ -139,7 +140,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * can take place. 
 */
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {

   // key used to store the spark secret in the Hadoop UGI
   private val sparkSecretLookupKey = "sparkCookie"

@@ -337,4 +338,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
    * @return the secret key as a String if authentication is enabled, otherwise returns null
    */
   def getSecretKey(): String = secretKey
+
+  override def getSaslUser(appId: String): String = {
+    val myAppId = sparkConf.getAppId
+    require(appId == myAppId, s"SASL appId $appId did not match my appId $myAppId")
+    getSaslUser()
+  }
+
+  override def getSecretKey(appId: String): String = {
+    val myAppId = sparkConf.getAppId
+    require(appId == myAppId, s"SASL appId $appId did not match my appId $myAppId")
+    getSecretKey()
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index ad0a9017afead..a205d88d1877e 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
    */
   getAll.filter { case (k, _) => isAkkaConf(k) }

+  /**
+   * Returns the Spark application id, which is valid on the driver once the TaskScheduler has
+   * registered, and on executors from the start.
+   */
+  def getAppId: String = get("spark.app.id")
+
   /** Does the configuration contain a given parameter? */
   def contains(key: String): Boolean = settings.contains(key)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8b4db783979ec..d65027d18e2d4 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
   val applicationId: String = taskScheduler.applicationId()
   conf.set("spark.app.id", applicationId)

+  env.blockManager.initialize(applicationId)
+
   val metricsSystem = env.metricsSystem

   // The metrics system for the driver needs spark.app.id to be set to the app ID.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e2f13accdfab5..45e9d7f243e96 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -276,7 +276,7 @@ object SparkEnv extends Logging {
     val blockTransferService =
       conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
         case "netty" =>
-          new NettyBlockTransferService(conf)
+          new NettyBlockTransferService(conf, securityManager)
         case "nio" =>
           new NioBlockTransferService(conf, securityManager)
       }
@@ -285,6 +285,7 @@ object SparkEnv extends Logging {
       "BlockManagerMaster",
       new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)

+    // NB: blockManager is not valid until initialize() is called later.
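+    // (It is initialized by SparkContext on the driver and by Executor on the executors, via
+    // blockManager.initialize(appId), once the application ID is known.)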
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala deleted file mode 100644 index a954fcc0c31fa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.RealmCallback -import javax.security.sasl.RealmChoiceCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslClient -import javax.security.sasl.SaslException - -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 - -/** - * Implements SASL Client logic for Spark - */ -private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { - - /** - * Used to respond to server's counterpart, SaslServer with SASL tokens - * represented as byte arrays. - * - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslClientCallbackHandler(securityMgr)) - - /** - * Used to initiate SASL handshake with server. - * @return response to challenge if needed - */ - def firstToken(): Array[Byte] = { - synchronized { - val saslToken: Array[Byte] = - if (saslClient != null && saslClient.hasInitialResponse()) { - logDebug("has initial response") - saslClient.evaluateChallenge(new Array[Byte](0)) - } else { - new Array[Byte](0) - } - saslToken - } - } - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslClient != null) saslClient.isComplete() else false - } - } - - /** - * Respond to server's SASL token. 
- * @param saslTokenMessage contains server's SASL token - * @return client's response SASL token - */ - def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { - synchronized { - if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - def dispose() { - synchronized { - if (saslClient != null) { - try { - saslClient.dispose() - } catch { - case e: SaslException => // ignored - } finally { - saslClient = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends - CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - private val secretKey = securityMgr.getSecretKey() - private val userPassword: Array[Char] = SparkSaslServer.encodePassword( - if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8)) - - /** - * Implementation used to respond to SASL request from the server. - * - * @param callbacks objects that indicate what credential information the - * server's SaslServer requires from the client. - */ - override def handle(callbacks: Array[Callback]) { - logDebug("in the sasl client callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL client callback: setting username: " + userName) - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL client callback: setting userPassword") - pc.setPassword(userPassword) - } - case rc: RealmCallback => { - logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case cb: RealmChoiceCallback => {} - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala deleted file mode 100644 index 7c2afb364661f..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.AuthorizeCallback -import javax.security.sasl.RealmCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslException -import javax.security.sasl.SaslServer -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 -import org.apache.commons.net.util.Base64 - -/** - * Encapsulates SASL server logic - */ -private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { - - /** - * Actual SASL work done by this object from javax.security.sasl. - */ - private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, - SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslDigestCallbackHandler(securityMgr)) - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslServer != null) saslServer.isComplete() else false - } - } - - /** - * Used to respond to server SASL tokens. - * @param token Server's SASL token - * @return response to send back to the server. - */ - def response(token: Array[Byte]): Array[Byte] = { - synchronized { - if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. - */ - def dispose() { - synchronized { - if (saslServer != null) { - try { - saslServer.dispose() - } catch { - case e: SaslException => // ignore - } finally { - saslServer = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * for SASL DIGEST-MD5 mechanism - */ - private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) - extends CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - - override def handle(callbacks: Array[Callback]) { - logDebug("In the sasl server callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL server callback: setting username") - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL server callback: setting userPassword") - val password: Array[Char] = - SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8)) - pc.setPassword(password) - } - case rc: RealmCallback => { - logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case ac: AuthorizeCallback => { - val authid = ac.getAuthenticationID() - val authzid = ac.getAuthorizationID() - if (authid.equals(authzid)) { - logDebug("set auth to true") - ac.setAuthorized(true) - } else { - logDebug("set auth to false") - ac.setAuthorized(false) - } - if (ac.isAuthorized()) { - logDebug("sasl server is authorized") - ac.setAuthorizedID(authzid) - } - } - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") - } - } - } -} - -private[spark] object SparkSaslServer { - - /** - * This is passed as the server name when creating the sasl client/server. 
- * This could be changed to be configurable in the future.
- */
-  val SASL_DEFAULT_REALM = "default"
-
-  /**
-   * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
-   * configurable in the future.
-   */
-  val DIGEST = "DIGEST-MD5"
-
-  /**
-   * The quality of protection is just "auth". This means that we are doing
-   * authentication only, we are not supporting integrity or privacy protection of the
-   * communication channel after authentication. This could be changed to be configurable
-   * in the future.
-   */
-  val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
-  /**
-   * Encode a byte[] identifier as a Base64-encoded string.
-   *
-   * @param identifier identifier to encode
-   * @return Base64-encoded string
-   */
-  def encodeIdentifier(identifier: Array[Byte]): String = {
-    new String(Base64.encodeBase64(identifier), UTF_8)
-  }
-
-  /**
-   * Encode a password as a base64-encoded char[] array.
-   * @param password as a byte array.
-   * @return password as a char array.
-   */
-  def encodePassword(password: Array[Byte]): Array[Char] = {
-    new String(Base64.encodeBase64(password), UTF_8).toCharArray()
-  }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e24a15f015e1c..7dd5265891c3e 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -86,6 +86,7 @@ private[spark] class Executor(
         conf, executorId, slaveHostname, port, isLocal, actorSystem)
       SparkEnv.set(_env)
       _env.metricsSystem.registerSource(executorSource)
+      _env.blockManager.initialize(conf.getAppId)
       _env
     } else {
       SparkEnv.get
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 1c4327cf13b51..70b04f39400d2 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,13 +17,15 @@

 package org.apache.spark.network.netty

+import scala.collection.JavaConversions._
 import scala.concurrent.{Future, Promise}

-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
 import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
+import org.apache.spark.network.sasl.{SaslBootstrap, SaslRpcHandler, SecretKeyHolder}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
 import org.apache.spark.serializer.JavaSerializer
@@ -33,18 +35,29 @@ import org.apache.spark.util.Utils
 /**
  * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
  */
-class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
+  extends BlockTransferService {
+
+  // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
- val serializer = new JavaSerializer(conf) + private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { + val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + if (!authEnabled) { + (nettyRpcHandler, None) + } else { + (new SaslRpcHandler(nettyRpcHandler, securityManager), + Some(new SaslBootstrap(conf.getAppId, securityManager))) + } + } transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler) - clientFactory = transportContext.createClientFactory() + clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() logInfo("Server created on " + server.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 4f6f5e235811d..c2d9578be7ebb 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,12 +23,13 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import org.apache.spark._ - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal +import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} + private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8408b75bb4d65..f198aa8564a54 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} import org.apache.spark.util.Utils import scala.util.Try @@ -600,7 +601,7 @@ private[nio] class ConnectionManager( } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) + replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -634,7 +635,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) } } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) @@ -778,7 +779,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { 
       conn.synchronized {
         if (conn.sparkSaslClient == null) {
-          conn.sparkSaslClient = new SparkSaslClient(securityManager)
+          conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
           var firstResponse: Array[Byte] = null
           try {
             firstResponse = conn.sparkSaslClient.firstToken()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 5f5dd0dc1c63f..02cc9b3924842 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -57,6 +57,12 @@ private[spark] class BlockResult(
   inputMetrics.bytesRead = bytes
 }

+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
 private[spark] class BlockManager(
     executorId: String,
     actorSystem: ActorSystem,
@@ -69,8 +75,6 @@ private[spark] class BlockManager(
     blockTransferService: BlockTransferService)
   extends BlockDataManager with Logging {

-  blockTransferService.init(this)
-
   val diskBlockManager = new DiskBlockManager(this, conf)

   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -102,22 +106,17 @@ private[spark] class BlockManager(
       + " switch to sort-based shuffle.")
   }

-  val blockManagerId = BlockManagerId(
-    executorId, blockTransferService.hostName, blockTransferService.port)
+  var blockManagerId: BlockManagerId = _

   // Address of the server that serves this executor's shuffle files. This is either an external
   // service, or just our own Executor's BlockManager.
-  private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) {
-    BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
-  } else {
-    blockManagerId
-  }
+  private[spark] var shuffleServerId: BlockManagerId = _

   // Client to read other executors' shuffle files. This is either an external service, or just the
   // standard BlockTransferService to directly connect to other Executors.
   private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
-    val appId = conf.get("spark.app.id", "unknown-app-id")
-    new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId)
+    new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf))
   } else {
     blockTransferService
   }
@@ -150,7 +149,7 @@ private[spark] class BlockManager(
   private val peerFetchLock = new Object
   private var lastPeerFetchTime = 0L

-  initialize()

   /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
    * the initialization of the compression codec until it is first used. The reason is that a Spark
@@ -176,10 +175,27 @@ private[spark] class BlockManager(
   }

   /**
-   * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
-   * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured.
+   * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+   * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+   * where it is only learned after registration with the TaskScheduler).
+   *
+   * This method initializes the BlockTransferService and ShuffleClient, registers with the
+   * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+   * service if configured.
    */
-  private def initialize(): Unit = {
+  def initialize(appId: String): Unit = {
+    blockTransferService.init(this)
+    shuffleClient.init(appId)
+
+    blockManagerId = BlockManagerId(
+      executorId, blockTransferService.hostName, blockTransferService.port)
+
+    shuffleServerId = if (externalShuffleServiceEnabled) {
+      BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+    } else {
+      blockManagerId
+    }
+
     master.registerBlockManager(blockManagerId, maxMemory, slaveActor)

     // Register Executors' configuration with the local shuffle service, if one should exist.
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index c6d7105592096..1461fa69db90d 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -63,6 +63,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
     val transfer = new NioBlockTransferService(conf, securityMgr)
     val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
       mapOutputTracker, shuffleManager, transfer)
+    store.initialize("app-id")
     allStores += store
     store
   }
@@ -263,6 +264,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
     when(failableTransfer.port).thenReturn(1000)
     val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
       10000, conf, mapOutputTracker, shuffleManager, failableTransfer)
+    failableStore.initialize("app-id")
     allStores += failableStore // so that this gets stopped after test
     assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 715b740b857b2..0782876c8e3c6 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NioBlockTransferService(conf, securityMgr)
-    new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+    val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
       mapOutputTracker, shuffleManager, transfer)
+    manager.initialize("app-id")
+    manager
   }

   before {
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index a271841e4e56c..5bc6e5a2418a9 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -17,12 +17,16 @@

 package org.apache.spark.network;

+import java.util.List;
+
+import com.google.common.collect.Lists;
 import io.netty.channel.Channel;
 import io.netty.channel.socket.SocketChannel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
 import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.client.TransportResponseHandler;
 import org.apache.spark.network.protocol.MessageDecoder;
@@ -64,8 +68,17 @@ public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
     this.decoder = new MessageDecoder();
   }

+  /**
+   * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
+   * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
+   * to create a Client.
+   */
+  public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
+    return new TransportClientFactory(this, bootstraps);
+  }
+
   public TransportClientFactory createClientFactory() {
-    return new TransportClientFactory(this);
+    return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
   }

   /** Create a server which will attempt to bind to a specific port. */
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 01c143fff423c..93510dab0b138 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -186,4 +186,9 @@ public void close() {
     // close is a local operation and should finish with milliseconds; timeout just to be safe
     channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
   }
+
+  /** Returns a stable key for this client's channel. Only valid after the channel is connected. */
+  public String getChannelKey() {
+    return channel.toString();
+  }
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
new file mode 100644
index 0000000000000..130b080db0f65
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user. This
+ * allows an initial handshake, such as SASL authentication, to be performed synchronously before
+ * the client is used; see TransportContext#createClientFactory.
+ */
+public interface TransportClientBootstrap {
+  void doBootstrap(TransportClient client);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 0b4a1d8286407..9897e63e2511b 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -21,10 +21,12 @@
 import java.lang.reflect.Field;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
+import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;

+import com.google.common.collect.Lists;
 import io.netty.bootstrap.Bootstrap;
 import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.Channel;
@@ -47,8 +49,11 @@
  * Factory for creating {@link TransportClient}s by using createClient.
  *
  * The factory maintains a connection pool to other hosts and should return the same
- * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
- * all {@link TransportClient}s.
+ * TransportClient for the same remote host. It also shares a single worker thread pool for
+ * all TransportClients.
+ *
+ * TransportClients will be reused whenever possible. Prior to completing the creation of a new
+ * TransportClient, all given {@link TransportClientBootstrap}s will be run.
  */
 public class TransportClientFactory implements Closeable {
   private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
@@ -56,14 +61,18 @@ public class TransportClientFactory implements Closeable {
   private final TransportContext context;
   private final TransportConf conf;
   private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
+  private final List<TransportClientBootstrap> clientBootstraps;

   private final Class<? extends Channel> socketChannelClass;
   private EventLoopGroup workerGroup;

-  public TransportClientFactory(TransportContext context) {
+  public TransportClientFactory(
+      TransportContext context,
+      List<TransportClientBootstrap> clientBootstraps) {
     this.context = context;
     this.conf = context.getConf();
     this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
+    this.clientBootstraps = Lists.newArrayList(clientBootstraps);

     IOMode ioMode = IOMode.valueOf(conf.ioMode());
     this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
@@ -72,9 +81,12 @@ public TransportClientFactory(TransportContext context) {
   }

   /**
-   * Create a new BlockFetchingClient connecting to the given remote host / port.
+   * Create a new {@link TransportClient} connecting to the given remote host / port. This will
+   * reuse TransportClients if they are still active and are for the same remote address. Prior
+   * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
+   * that are registered with this factory.
    *
-   * This blocks until a connection is successfully established.
+   * This blocks until a connection is successfully established and fully bootstrapped.
    *
    * Concurrency: This method is safe to call from multiple threads.
 */
@@ -104,13 +116,13 @@ public TransportClient createClient(String remoteHost, int remotePort) {
     // Use pooled buffers to reduce temporary buffer allocation
     bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());

-    final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+    final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();

     bootstrap.handler(new ChannelInitializer<SocketChannel>() {
       @Override
       public void initChannel(SocketChannel ch) {
         TransportChannelHandler clientHandler = context.initializePipeline(ch);
-        client.set(clientHandler.getClient());
+        clientRef.set(clientHandler.getClient());
       }
     });

@@ -123,15 +135,24 @@ public void initChannel(SocketChannel ch) {
       throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
     }

-    // Successful connection -- in the event that two threads raced to create a client, we will
+    TransportClient client = clientRef.get();
+    assert client != null : "Channel future completed successfully with null client";
+
+    logger.debug("Connection to {} successful, running bootstraps...", address);
+    // Execute any client bootstraps synchronously before marking the Client as successful.
+    for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+      clientBootstrap.doBootstrap(client);
+    }
+
+    logger.debug("Successfully executed {} bootstraps for {}", clientBootstraps.size(), address);
+    // Successful connection & bootstrap -- in the event that two threads raced to create a client,
     // use the first one that was put into the connectionPool and close the one we made here.
-    assert client.get() != null : "Channel future completed successfully with null client";
-    TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+    TransportClient oldClient = connectionPool.putIfAbsent(address, client);
     if (oldClient == null) {
-      return client.get();
+      return client;
     } else {
       logger.debug("Two clients were created concurrently, second one will be disposed.");
-      client.get().close();
+      client.close();
       return oldClient;
     }
   }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 2369dc6203944..40b45a30fdf59 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -28,6 +28,8 @@
    * Receive a single RPC message. Any exception thrown while in this method will be sent back to
    * the client in string form as a standard RPC failure.
    *
+   * This method will not be called in parallel for a single TransportClient (i.e., channel).
+   *
    * @param client A channel client which enables the handler to make requests back to the sender
    *               of this RPC.
    * @param message The serialized bytes of the RPC.
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java
new file mode 100644
index 0000000000000..2b9f1262972e3
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be set up with a {@link SaslRpcHandler} holding a matching secret for the given
+ * appId.
+ */
+public class SaslBootstrap implements TransportClientBootstrap {
+  private final Logger logger = LoggerFactory.getLogger(SaslBootstrap.class);
+
+  private final String secretKeyId;
+  private final SecretKeyHolder secretKeyHolder;
+
+  public SaslBootstrap(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+    this.secretKeyId = secretKeyId;
+    this.secretKeyHolder = secretKeyHolder;
+  }
+
+  /**
+   * Performs SASL authentication by exchanging challenge-response tokens with the server until
+   * the exchange either completes successfully or throws an exception due to a mismatch.
+   */
+  @Override
+  public void doBootstrap(TransportClient client) {
+    SparkSaslClient saslClient = new SparkSaslClient(secretKeyId, secretKeyHolder);
+    try {
+      byte[] payload = saslClient.firstToken();
+
+      while (!saslClient.isComplete()) {
+        SaslMessage msg = new SaslMessage(secretKeyId, payload);
+        logger.debug("Sending SASL token for {} ({} bytes)", secretKeyId, payload.length);
+        ByteBuf buf = Unpooled.buffer(msg.encodedLength());
+        msg.encode(buf);
+
+        byte[] response = client.sendRpcSync(buf.array(), 300000 /* 5 minute timeout */);
+        logger.debug("Received SASL response for {} ({} bytes)", secretKeyId, response.length);
+        payload = saslClient.response(response);
+      }
+    } finally {
+      // Dispose of security-sensitive state once the handshake is over, successful or not.
+      saslClient.dispose();
+    }
+  }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
new file mode 100644
index 0000000000000..709d184034c94
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.concurrent.ConcurrentMap;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ */
+public class SaslRpcHandler implements RpcHandler {
+  private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+  private final RpcHandler delegate;
+  private final SecretKeyHolder secretKeyHolder;
+  private final ConcurrentMap<String, SparkSaslServer> channelAuthenticationMap;
+
+  public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
+    this.delegate = delegate;
+    this.secretKeyHolder = secretKeyHolder;
+    this.channelAuthenticationMap = Maps.newConcurrentMap();
+  }
+
+  @Override
+  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+    String channelKey = client.getChannelKey();
+
+    SparkSaslServer saslServer = channelAuthenticationMap.get(channelKey);
+    if (saslServer != null && saslServer.isComplete()) {
+      // Authentication complete, delegate to base handler.
+      delegate.receive(client, message, callback);
+      return;
+    }
+
+    SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+
+    if (saslServer == null) {
+      saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
+      channelAuthenticationMap.put(channelKey, saslServer);
+    }
+
+    byte[] response = saslServer.response(saslMessage.payload);
+    if (saslServer.isComplete()) {
+      logger.debug("SASL authentication successful for channel {}", channelKey);
+    }
+    callback.onSuccess(response);
+  }
+
+  @Override
+  public StreamManager getStreamManager() {
+    return delegate.getStreamManager();
+  }
+}
+
+/**
+ * Encodes a SASL-related message which is attempting to authenticate using some credentials tagged
+ * with the given id. This 'id' allows a single SaslRpcHandler to multiplex different applications
+ * which may be using different sets of credentials.
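+ *
+ * Wire format, as produced by encode() below: a single tag byte (0xEA), the length of the UTF-8
+ * appId as an int followed by its bytes, and the length of the payload as an int followed by its
+ * bytes.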
+ */
+class SaslMessage implements Encodable {
+
+  private static final byte TAG_BYTE = (byte) 0xEA;
+
+  public final String appId;
+  public final byte[] payload;
+
+  public SaslMessage(String appId, byte[] payload) {
+    this.appId = appId;
+    this.payload = payload;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeByte(TAG_BYTE);
+    byte[] idBytes = appId.getBytes(Charsets.UTF_8);
+    buf.writeInt(idBytes.length);
+    buf.writeBytes(idBytes);
+    buf.writeInt(payload.length);
+    buf.writeBytes(payload);
+  }
+
+  public static SaslMessage decode(ByteBuf buf) {
+    if (buf.readByte() != TAG_BYTE) {
+      throw new IllegalStateException("Expected SaslMessage, received something else");
+    }
+
+    int idLength = buf.readInt();
+    byte[] idBytes = new byte[idLength];
+    buf.readBytes(idBytes);
+
+    int payloadLength = buf.readInt();
+    byte[] payload = new byte[payloadLength];
+    buf.readBytes(payload);
+
+    return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
+  }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
new file mode 100644
index 0000000000000..aaf50d2b62dd2
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * Interface for getting a secret key associated with some appId.
+ */
+public interface SecretKeyHolder {
+  /**
+   * Gets an appropriate SASL User for the given appId.
+   * @throws IllegalArgumentException if the given appId is not associated with a SASL user.
+   */
+  String getSaslUser(String appId);
+
+  /**
+   * Gets an appropriate SASL secret key for the given appId.
+   * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key.
+   */
+  String getSecretKey(String appId);
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
new file mode 100644
index 0000000000000..72ba737b998bc
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+
+import com.google.common.base.Throwables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.sasl.SparkSaslServer.*;
+
+/**
+ * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. This client initializes the protocol via a
+ * firstToken, which is then followed by a set of challenges and responses.
+ */
+public class SparkSaslClient {
+  private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
+
+  private final String secretKeyId;
+  private final SecretKeyHolder secretKeyHolder;
+  private SaslClient saslClient;
+
+  public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+    this.secretKeyId = secretKeyId;
+    this.secretKeyHolder = secretKeyHolder;
+    try {
+      this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
+        SASL_PROPS, new ClientCallbackHandler());
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /** Used to initiate SASL handshake with server. */
+  public synchronized byte[] firstToken() {
+    if (saslClient != null && saslClient.hasInitialResponse()) {
+      try {
+        return saslClient.evaluateChallenge(new byte[0]);
+      } catch (SaslException e) {
+        throw Throwables.propagate(e);
+      }
+    } else {
+      return new byte[0];
+    }
+  }
+
+  /** Determines whether the authentication exchange has completed. */
+  public synchronized boolean isComplete() {
+    return saslClient != null && saslClient.isComplete();
+  }
+
+  /**
+   * Respond to server's SASL token.
+   * @param token contains server's SASL token
+   * @return client's response SASL token
+   */
+  public synchronized byte[] response(byte[] token) {
+    try {
+      return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information the
+   * SaslClient might be using.
+   */
+  public synchronized void dispose() {
+    if (saslClient != null) {
+      try {
+        saslClient.dispose();
+      } catch (SaslException e) {
+        // ignore
+      } finally {
+        saslClient = null;
+      }
+    }
+  }
+
+  /**
+   * Implementation of javax.security.auth.callback.CallbackHandler
+   * that works with shared secrets.
+   */
+  private class ClientCallbackHandler implements CallbackHandler {
+    @Override
+    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+      for (Callback callback : callbacks) {
+        if (callback instanceof NameCallback) {
+          logger.trace("SASL client callback: setting username");
+          NameCallback nc = (NameCallback) callback;
+          nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+        } else if (callback instanceof PasswordCallback) {
+          logger.trace("SASL client callback: setting password");
+          PasswordCallback pc = (PasswordCallback) callback;
+          pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+        } else if (callback instanceof RealmCallback) {
+          logger.trace("SASL client callback: setting realm");
+          RealmCallback rc = (RealmCallback) callback;
+          rc.setText(rc.getDefaultText());
+        } else if (callback instanceof RealmChoiceCallback) {
+          // ignore
+        } else {
+          throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+        }
+      }
+    }
+  }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
new file mode 100644
index 0000000000000..937370355dfa8
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.io.BaseEncoding;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. (It is not a server in the sense of accepting
+ * connections on some socket.)
+ */
+public class SparkSaslServer {
+  private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
+
+  /**
+   * This is passed as the server name when creating the sasl client/server.
+   * This could be changed to be configurable in the future.
+   */
+  static final String DEFAULT_REALM = "default";
+
+  /**
+   * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+   * configurable in the future.
+   */
+  static final String DIGEST = "DIGEST-MD5";
+
+  /**
+   * The quality of protection is just "auth". This means that we are doing
+   * authentication only, we are not supporting integrity or privacy protection of the
+   * communication channel after authentication. This could be changed to be configurable
+   * in the future.
+   */
+  static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
+    .put(Sasl.QOP, "auth")
+    .put(Sasl.SERVER_AUTH, "true")
+    .build();
+
+  /** Identifier for a certain secret key within the secretKeyHolder. */
+  private final String secretKeyId;
+  private final SecretKeyHolder secretKeyHolder;
+  private SaslServer saslServer;
+
+  public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+    this.secretKeyId = secretKeyId;
+    this.secretKeyHolder = secretKeyHolder;
+    try {
+      this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
+        new DigestCallbackHandler());
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /**
+   * Determines whether the authentication exchange has completed successfully.
+   */
+  public synchronized boolean isComplete() {
+    return saslServer != null && saslServer.isComplete();
+  }
+
+  /**
+   * Used to respond to client SASL tokens.
+   * @param token Client's SASL token
+   * @return response to send back to the client.
+   */
+  public synchronized byte[] response(byte[] token) {
+    try {
+      return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information the
+   * SaslServer might be using.
+   */
+  public synchronized void dispose() {
+    if (saslServer != null) {
+      try {
+        saslServer.dispose();
+      } catch (SaslException e) {
+        // ignore
+      } finally {
+        saslServer = null;
+      }
+    }
+  }
+
+  /**
+   * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
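+   * Supplies the SASL user name and password (both derived from the SecretKeyHolder) and
+   * authorizes the exchange only when the authentication and authorization IDs match.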
+ */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + if (password != null) { + return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + } else { + return new char[0]; + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6bbabc44b958b..b0b19ba67bddc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,8 +17,6 @@ package org.apache.spark.network.shuffle; -import java.io.Closeable; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,15 +34,20 @@ * BlockTransferService), which has the downside of losing the shuffle data if we lose the * executors. 
*/ -public class ExternalShuffleClient implements ShuffleClient { +public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportClientFactory clientFactory; - private final String appId; - public ExternalShuffleClient(TransportConf conf, String appId) { + private String appId; + + public ExternalShuffleClient(TransportConf conf) { TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); this.clientFactory = context.createClientFactory(); + } + + @Override + public void init(String appId) { this.appId = appId; } @@ -55,6 +58,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener) { + assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { TransportClient client = clientFactory.createClient(host, port); @@ -82,6 +86,7 @@ public void registerWithShuffleServer( int port, String execId, ExecutorShuffleInfo executorInfo) { + assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index d46a562394557..f72ab40690d0d 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -20,7 +20,14 @@ import java.io.Closeable; /** Provides an interface for reading shuffle files, either from an Executor or external service. */ -public interface ShuffleClient extends Closeable { +public abstract class ShuffleClient implements Closeable { + + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. + */ + public void init(String appId) { } + /** * Fetch a sequence of blocks from a remote node asynchronously, * @@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable { * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. 
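With the constructor/init() split above, an ExternalShuffleClient can now be constructed before the application id is known, which is the driver's situation since the appId is only learned after registration with the TaskScheduler. A usage sketch follows, assuming a reachable shuffle server; the host, port, executor id, block id, and appId values are placeholders.

    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.shuffle.BlockFetchingListener;
    import org.apache.spark.network.shuffle.ExternalShuffleClient;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class ShuffleClientUsageSketch {
      public static void main(String[] args) {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        ExternalShuffleClient client = new ExternalShuffleClient(conf);
        client.init("app-20141103-0001");  // must precede any fetch or registration

        client.fetchBlocks("shuffle-host", 7337, "exec-1",
            new String[] { "shuffle_0_0_0" },
            new BlockFetchingListener() {
              @Override
              public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
                // consume the block data; may fire before other blocks arrive
              }

              @Override
              public void onBlockFetchFailure(String blockId, Throwable exception) {
                // per-block failure handling
              }
            });
      }
    }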
*/ - public void fetchBlocks( + public abstract void fetchBlocks( String host, int port, String execId, diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b3bcf5fd68e73..28113ca5d405b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,8 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @Override @@ -265,7 +266,8 @@ public void testFetchNoServer() throws Exception { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index ad1a6f01b3a57..0f27f55fec4f3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -74,6 +74,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr)) + blockManager.initialize("app-id") tempDirectory = Files.createTempDir() manualClock.setTime(0) From 7b42adb96832c4bf536efbc0066670ce5b47ebd2 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 3 Nov 2014 21:36:17 -0800 Subject: [PATCH 02/10] Add unit tests --- .../apache/spark/storage/BlockManager.scala | 5 +- .../network/nio/ConnectionManagerSuite.scala | 6 +- .../spark/network/client/TransportClient.java | 3 +- .../spark/network/sasl/SaslMessage.java | 74 ++++++++ .../spark/network/sasl/SaslRpcHandler.java | 50 +---- .../network/sasl/SaslIntegrationSuite.java | 172 ++++++++++++++++++ .../spark/network/sasl/SparkSaslSuite.java | 89 +++++++++ .../network/sasl/TestSecretKeyHolder.java | 18 ++ .../ExternalShuffleIntegrationSuite.java | 1 + 9 files changed, 362 insertions(+), 56 deletions(-) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/TestSecretKeyHolder.java diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 02cc9b3924842..414f2ec5afa4c 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -107,7 +107,6 @@ private[spark] class BlockManager( } var blockManagerId: BlockManagerId = _ -// BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port) // Address of the server that serves this executor's shuffle files. This is either an external // service, or just our own Executor's BlockManager. @@ -149,8 +148,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L -// initialize() - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -179,7 +176,7 @@ private[spark] class BlockManager( * the appId may not be known at BlockManager instantiation time (in particular for the driver, * where it is only learned after registration with the TaskScheduler). * - * This method initialies the BlockTransferService and ShuffleClient registers with the + * This method initializes the BlockTransferService and ShuffleClient registers with the * BlockManagerMaster, starts theBlockManagerWorker actor, and registers with a local shuffle * service if configured. */ diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index b70734dfe37cf..716f875d30b8a 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") + conf.set("spark.app.id", "app-id") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) var numReceivedMessages = 0 @@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite { test("security mismatch password") { val conf = new SparkConf conf.set("spark.authenticate", "true") + conf.set("spark.app.id", "app-id") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) @@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite { None }) - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") + val badconf = conf.clone.set("spark.authenticate.secret", "bad") val badsecurityManager = new SecurityManager(badconf) val managerServer = new ConnectionManager(0, badconf, badsecurityManager) var numReceivedServerMessages = 0 diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 93510dab0b138..b670ab6dabeac 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -189,6 +189,7 @@ public void close() { /** Returns a stable key for the given channel. Only valid after the channel is connected. 
*/ public String getChannelKey() { - return channel.toString(); + return String.format("[%s, %s, %s]", channel.remoteAddress(), channel.localAddress(), + channel.hashCode()); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 0000000000000..2a28ae0f63e58 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. This appId allows a single SaslRpcHandler to multiplex different + * applications who may be using different sets of credentials. + */ +class SaslMessage implements Encodable { + + /** Serialization tag used to catch incorrect payloads. 
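Concretely, each SaslMessage is framed on the wire as [tag 0xEA: 1 byte][appId length: 4][appId UTF-8 bytes][payload length: 4][payload bytes]. A round-trip sketch, placed in the same package since the class is package-private:

    package org.apache.spark.network.sasl;  // SaslMessage is package-private

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    public class SaslMessageRoundTrip {
      public static void main(String[] args) {
        SaslMessage original = new SaslMessage("app-id", new byte[] { 1, 2, 3 });

        ByteBuf buf = Unpooled.buffer(original.encodedLength());
        original.encode(buf);

        SaslMessage decoded = SaslMessage.decode(buf);
        System.out.println(decoded.appId);           // app-id
        System.out.println(decoded.payload.length);  // 3
        // A buffer that does not begin with the 0xEA tag byte makes decode()
        // throw IllegalStateException ("Expected SaslMessage...").
      }
    }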
*/ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + public final byte[] payload; + + public SaslMessage(String appId, byte[] payload) { + this.appId = appId; + this.payload = payload; + } + + @Override + public int encodedLength() { + // tag + appIdLength + appId + payloadLength + payload + return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + byte[] idBytes = appId.getBytes(Charsets.UTF_8); + buf.writeInt(idBytes.length); + buf.writeBytes(appId.getBytes(Charsets.UTF_8)); + buf.writeInt(payload.length); + buf.writeBytes(payload); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else"); + } + + int idLength = buf.readInt(); + byte[] idBytes = new byte[idLength]; + buf.readBytes(idBytes); + + int payloadLength = buf.readInt(); + byte[] payload = new byte[payloadLength]; + buf.readBytes(payload); + + return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 709d184034c94..2c5d58fe07f87 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -42,6 +42,8 @@ public class SaslRpcHandler implements RpcHandler { private final RpcHandler delegate; private final SecretKeyHolder secretKeyHolder; + + // TODO: Invalidate channels that have closed! private final ConcurrentMap channelAuthenticationMap; public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { @@ -81,51 +83,3 @@ public StreamManager getStreamManager() { } } -/** - * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged - * with the given id. This 'id' allows a single SaslRpcHandler to multiplex different applications - * who may be using different sets of credentials. 
- */ -class SaslMessage implements Encodable { - - private static final byte TAG_BYTE = (byte) 0xEA; - - public final String appId; - public final byte[] payload; - - public SaslMessage(String appId, byte[] payload) { - this.appId = appId; - this.payload = payload; - } - - @Override - public int encodedLength() { - return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; - } - - @Override - public void encode(ByteBuf buf) { - buf.writeByte(TAG_BYTE); - byte[] idBytes = appId.getBytes(Charsets.UTF_8); - buf.writeInt(idBytes.length); - buf.writeBytes(appId.getBytes(Charsets.UTF_8)); - buf.writeInt(payload.length); - buf.writeBytes(payload); - } - - public static SaslMessage decode(ByteBuf buf) { - if (buf.readByte() != TAG_BYTE) { - throw new IllegalStateException("Expected SaslMessage, received something else"); - } - - int idLength = buf.readInt(); - byte[] idBytes = new byte[idLength]; - buf.readBytes(idBytes); - - int payloadLength = buf.readInt(); - byte[] payload = new byte[payloadLength]; - buf.readBytes(payload); - - return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 0000000000000..c653dda5d20be --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + static TransportContext context; + + TransportClientFactory clientFactory; + + /** Provides a secret key holder which always returns the given secret key. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + + private final String secretKey; + + TestSecretKeyHolder(String secretKey) { + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + @Override + public String getSecretKey(String appId) { + return secretKey; + } + } + + + @BeforeClass + public static void beforeAll() throws IOException { + SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); + SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); + conf = new TransportConf(new SystemPropertyConfigProvider()); + context = new TransportContext(conf, handler); + server = context.createServer(); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslBootstrap("app-id", new TestSecretKeyHolder("good-key")))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + } + + @Test + public void testBadClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslBootstrap("app-id", new TestSecretKeyHolder("bad-key")))); + + try { + // Bootstrap should fail on startup. 
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(new byte[13], 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslBootstrap("app-id", new TestSecretKeyHolder("key")))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** RPC handler which simply responds with the message it received. */ + public static class TestRpcHandler implements RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 0000000000000..67a07f38eb5a0 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. 
+ */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/TestSecretKeyHolder.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/TestSecretKeyHolder.java new file mode 100644 index 0000000000000..e3a69c25cf60d --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/TestSecretKeyHolder.java @@ -0,0 +1,18 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 28113ca5d405b..bc101f53844d5 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -165,6 +165,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); } + client.close(); return res; } From f6177d73f8a644b881e37ebdaf7728fe9657be1b Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 3 Nov 2014 21:57:18 -0800 Subject: [PATCH 03/10] Cleanup SASL state upon connection termination --- .../spark/network/client/TransportClient.java | 13 ++++---- .../spark/network/server/NoOpRpcHandler.java | 2 +- .../spark/network/server/RpcHandler.java | 17 +++++++--- .../server/TransportRequestHandler.java | 1 + .../spark/network/sasl/SaslBootstrap.java | 31 ++++++++++++------- .../spark/network/sasl/SaslRpcHandler.java | 24 +++++++++----- .../shuffle/ExternalShuffleBlockHandler.java | 2 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- 8 files changed, 61 insertions(+), 31 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index b670ab6dabeac..a08cee02dd576 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,10 +19,9 @@ import java.io.Closeable; import java.util.UUID; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; @@ -187,9 +186,11 @@ public void close() { channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } - /** Returns a stable key for the given channel. Only valid after the channel is connected. */ - public String getChannelKey() { - return String.format("[%s, %s, %s]", channel.remoteAddress(), channel.localAddress(), - channel.hashCode()); + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("remoteAdress", channel.remoteAddress()) + .add("isActive", isActive()) + .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 5a3f003726fc1..1502b7489e864 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -21,7 +21,7 @@ import org.apache.spark.network.client.TransportClient; /** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. 
*/ -public class NoOpRpcHandler implements RpcHandler { +public class NoOpRpcHandler extends RpcHandler { private final StreamManager streamManager; public NoOpRpcHandler() { diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 40b45a30fdf59..2ba92a40f8b0a 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,7 +23,7 @@ /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ -public interface RpcHandler { +public abstract class RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. @@ -31,16 +31,25 @@ public interface RpcHandler { * This method will not be called in parallel for a single TransportClient (i.e., channel). * * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. + * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. */ - void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + public abstract void receive( + TransportClient client, + byte[] message, + RpcResponseCallback callback); /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. */ - StreamManager getStreamManager(); + public abstract StreamManager getStreamManager(); + + /** + * Invoked when the connection associated with the given client has been invalidated. + * No further requests will come from this client. 
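Turning RpcHandler into an abstract class lets connectionTerminated() default to a no-op, so handlers only override it when they keep per-connection state. A sketch of such a handler; the class and its counter are illustrative, but the cleanup pattern is the one SaslRpcHandler adopts below.

    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;
    import java.util.concurrent.atomic.AtomicInteger;

    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.server.OneForOneStreamManager;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.StreamManager;

    public class CountingRpcHandler extends RpcHandler {
      // Keyed on the TransportClient, which is the same object for every RPC
      // arriving on a given channel (see the updated receive() javadoc).
      private final ConcurrentMap<TransportClient, AtomicInteger> counts =
          new ConcurrentHashMap<TransportClient, AtomicInteger>();

      @Override
      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
        counts.putIfAbsent(client, new AtomicInteger(0));
        int n = counts.get(client).incrementAndGet();
        callback.onSuccess(Integer.toString(n).getBytes());
      }

      @Override
      public StreamManager getStreamManager() {
        return new OneForOneStreamManager();
      }

      @Override
      public void connectionTerminated(TransportClient client) {
        counts.remove(client);  // no further requests will arrive on this channel
      }
    }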
+ */ + public void connectionTerminated(TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 17fe9001b35cc..1580180cc17e9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -86,6 +86,7 @@ public void channelUnregistered() { for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); } + rpcHandler.connectionTerminated(reverseClient); } @Override diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java index 2b9f1262972e3..0ecfb75591e10 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java @@ -38,17 +38,26 @@ public SaslBootstrap(String secretKeyId, SecretKeyHolder secretKeyHolder) { public void doBootstrap(TransportClient client) { SparkSaslClient saslClient = new SparkSaslClient(secretKeyId, secretKeyHolder); - byte[] payload = saslClient.firstToken(); - - while (!saslClient.isComplete()) { - SaslMessage msg = new SaslMessage(secretKeyId, payload); - logger.info("Sending msg {} {}", secretKeyId, payload.length); - ByteBuf buf = Unpooled.buffer(msg.encodedLength()); - msg.encode(buf); - - byte[] response = client.sendRpcSync(buf.array(), 300000); - logger.info("Got response {} {}", secretKeyId, response.length); - payload = saslClient.response(response); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(secretKeyId, payload); + logger.info("Sending msg {} {}", secretKeyId, payload.length); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), 300000); + logger.info("Got response {} {}", secretKeyId, response.length); + payload = saslClient.response(response); + } + } finally { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } } } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 2c5d58fe07f87..f03ba31b121aa 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -36,15 +36,18 @@ * RPC Handler which performs SASL authentication before delegating to a child RPC handler. * The delegate will only receive messages if the given connection has been successfully * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. 
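Server-side wiring is a one-line wrap, as in SaslIntegrationSuite.beforeAll() above. A sketch that borrows the suite's TestRpcHandler and TestSecretKeyHolder for brevity; any RpcHandler/SecretKeyHolder pair works the same way.

    package org.apache.spark.network.sasl;  // for the suite's package-private helpers

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class SaslServerWiringSketch {
      public static void main(String[] args) {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        RpcHandler appHandler = new SaslIntegrationSuite.TestRpcHandler();
        SecretKeyHolder keys = new SaslIntegrationSuite.TestSecretKeyHolder("good-key");

        // Unauthenticated RPCs stop in SaslRpcHandler; authenticated ones are
        // delegated to appHandler on the same channel.
        TransportContext context = new TransportContext(conf, new SaslRpcHandler(appHandler, keys));
        TransportServer server = context.createServer();
        System.out.println("listening on " + server.getPort());
      }
    }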
*/ -public class SaslRpcHandler implements RpcHandler { +public class SaslRpcHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); private final RpcHandler delegate; private final SecretKeyHolder secretKeyHolder; // TODO: Invalidate channels that have closed! - private final ConcurrentMap channelAuthenticationMap; + private final ConcurrentMap channelAuthenticationMap; public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { this.delegate = delegate; @@ -54,9 +57,7 @@ public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - String channelKey = client.getChannelKey(); - - SparkSaslServer saslServer = channelAuthenticationMap.get(channelKey); + SparkSaslServer saslServer = channelAuthenticationMap.get(client); if (saslServer != null && saslServer.isComplete()) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); @@ -66,13 +67,14 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); if (saslServer == null) { + // First message in the handshake, setup the necessary state. saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); - channelAuthenticationMap.put(channelKey, saslServer); + channelAuthenticationMap.put(client, saslServer); } byte[] response = saslServer.response(saslMessage.payload); if (saslServer.isComplete()) { - logger.debug("SASL authentication successful for channel {}", channelKey); + logger.debug("SASL authentication successful for channel {}", client); } callback.onSuccess(response); } @@ -81,5 +83,13 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback public StreamManager getStreamManager() { return delegate.getStreamManager(); } + + @Override + public void connectionTerminated(TransportClient client) { + SparkSaslServer saslServer = channelAuthenticationMap.remove(client); + if (saslServer != null) { + saslServer.dispose(); + } + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index a9dff31decc83..cd3fea85b19a4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -41,7 +41,7 @@ * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- * level shuffle block. 
*/ -public class ExternalShuffleBlockHandler implements RpcHandler { +public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); private final ExternalShuffleBlockManager blockManager; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c653dda5d20be..73dd8e264e580 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -158,7 +158,7 @@ public void testNoSaslServer() { } /** RPC handler which simply responds with the message it received. */ - public static class TestRpcHandler implements RpcHandler { + public static class TestRpcHandler extends RpcHandler { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { callback.onSuccess(message); From 151b3c5a599ab6c51f4b9edccf8a43e4cb9e87e4 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 3 Nov 2014 22:28:32 -0800 Subject: [PATCH 04/10] Add docs, timeout config, better failure handling --- .../scala/org/apache/spark/SparkConf.scala | 4 +-- .../netty/NettyBlockTransferService.scala | 7 +++-- .../apache/spark/storage/BlockManager.scala | 4 +-- .../client/TransportClientBootstrap.java | 12 +++++++- .../client/TransportClientFactory.java | 21 ++++++++++---- .../spark/network/util/TransportConf.java | 3 ++ .../spark/network/sasl/SaslBootstrap.java | 28 +++++++++++++------ .../network/sasl/SaslIntegrationSuite.java | 6 ++-- 8 files changed, 61 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a205d88d1877e..4c6c86c7bad78 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -218,8 +218,8 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { getAll.filter { case (k, _) => isAkkaConf(k) } /** - * Returns the Spark application id, valid in the Driver after TaskScheduler registration in the - * driver and from the start in the Executor. + * Returns the Spark application id, valid in the Driver after TaskScheduler registration and + * from the start in the Executor. */ def getAppId: String = get("spark.app.id") diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 70b04f39400d2..017666965ddc5 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import org.apache.spark.network.util.TransportConf + import scala.collection.JavaConversions._ import scala.concurrent.{Future, Promise} @@ -41,6 +43,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() + private val transportConf = SparkTransportConf.fromSparkConf(conf) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ @@ -53,10 +56,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage (nettyRpcHandler, None) } else { (new SaslRpcHandler(nettyRpcHandler, securityManager), - Some(new SaslBootstrap(conf.getAppId, securityManager))) + Some(new SaslBootstrap(transportConf, conf.getAppId, securityManager))) } } - transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler) + transportContext = new TransportContext(transportConf, rpcHandler) clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() logInfo("Server created on " + server.getPort) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 414f2ec5afa4c..655d16c65c8b5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -176,8 +176,8 @@ private[spark] class BlockManager( * the appId may not be known at BlockManager instantiation time (in particular for the driver, * where it is only learned after registration with the TaskScheduler). * - * This method initializes the BlockTransferService and ShuffleClient registers with the - * BlockManagerMaster, starts theBlockManagerWorker actor, and registers with a local shuffle + * This method initializes the BlockTransferService and ShuffleClient, registers with the + * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle * service if configured. */ def initialize(appId: String): Unit = { diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java index 130b080db0f65..65e8020e34121 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -17,6 +17,16 @@ package org.apache.spark.network.client; +/** + * A bootstrap which is executed on a TransportClient before it is returned to the user. + * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- + * connection basis. + * + * Since connections (and TransportClients) are reused as much as possible, it is generally + * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with + * the JVM itself. + */ public interface TransportClientBootstrap { - public void doBootstrap(TransportClient client); + /** Performs the bootstrapping operation, throwing an exception on failure. 
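Since the contract is a single synchronous method, a custom bootstrap is a small class. The "ping"/"pong" exchange below is invented for illustration; SaslBootstrap is this patch's real use of the hook.

    import com.google.common.base.Charsets;

    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;

    public class PingPongBootstrap implements TransportClientBootstrap {
      @Override
      public void doBootstrap(TransportClient client) throws RuntimeException {
        // Runs synchronously on the new connection, before createClient() returns it.
        byte[] reply = client.sendRpcSync("ping".getBytes(Charsets.UTF_8), 5000);
        if (!"pong".equals(new String(reply, Charsets.UTF_8))) {
          // Throwing aborts the connection: TransportClientFactory closes the
          // client and propagates the failure to the caller.
          throw new RuntimeException("Handshake failed: unexpected reply");
        }
      }
    }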
*/ + public void doBootstrap(TransportClient client) throws RuntimeException; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 9897e63e2511b..8aa860dbd6d92 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -26,6 +26,8 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; import com.google.common.collect.Lists; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; @@ -42,6 +44,7 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -60,8 +63,8 @@ public class TransportClientFactory implements Closeable { private final TransportContext context; private final TransportConf conf; - private final ConcurrentHashMap connectionPool; private final List clientBootstraps; + private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; private EventLoopGroup workerGroup; @@ -69,10 +72,10 @@ public class TransportClientFactory implements Closeable { public TransportClientFactory( TransportContext context, List clientBootstraps) { - this.context = context; + this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); + this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap(); - this.clientBootstraps = Lists.newArrayList(clientBootstraps); IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); @@ -138,10 +141,16 @@ public void initChannel(SocketChannel ch) { TransportClient client = clientRef.get(); assert client != null : "Channel future completed successfully with null client"; - logger.debug("Connection to {} successful, running bootstraps...", address); // Execute any client bootstraps synchronously before marking the Client as successful. - for (TransportClientBootstrap clientBootstrap : clientBootstraps) { - clientBootstrap.doBootstrap(client); + logger.debug("Connection to {} successful, running bootstraps...", address); + try { + for (TransportClientBootstrap clientBootstrap : clientBootstraps) { + clientBootstrap.doBootstrap(client); + } + } catch (Exception e) { // catch Exception as the bootstrap may be written in Scala + logger.error("Exception while bootstrapping client", e); + client.close(); + throw Throwables.propagate(e); } logger.debug("Successfully executed {} bootstraps for {}", clientBootstraps.size(), address); diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index a68f38e0e94c9..823790dd3c66f 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -55,4 +55,7 @@ public int connectionTimeoutMs() { /** Send buffer size (SO_SNDBUF). 
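On the client side, bootstraps are registered when the factory is created and run before createClient() returns, as testGoodClient shows. A wiring sketch using this commit's updated SaslBootstrap constructor (shown just below); the host and port are placeholders, and the suite's package-private key holder is borrowed for brevity.

    package org.apache.spark.network.sasl;  // for the suite's package-private helper

    import com.google.common.collect.Lists;

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;
    import org.apache.spark.network.client.TransportClientFactory;
    import org.apache.spark.network.server.NoOpRpcHandler;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class SaslClientWiringSketch {
      public static void main(String[] args) throws Exception {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        TransportContext context = new TransportContext(conf, new NoOpRpcHandler());

        TransportClientFactory factory = context.createClientFactory(
            Lists.<TransportClientBootstrap>newArrayList(
                new SaslBootstrap(conf, "app-id",
                    new SaslIntegrationSuite.TestSecretKeyHolder("good-key"))));

        // If any bootstrap throws, the factory closes the connection and
        // propagates the exception here instead of returning a client.
        TransportClient client = factory.createClient("sasl-host", 7337);
        System.out.println("authenticated client: " + client);
      }
    }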
*/ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ + public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java index 0ecfb75591e10..8dea863e6f4e4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java @@ -24,31 +24,43 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ public class SaslBootstrap implements TransportClientBootstrap { private final Logger logger = LoggerFactory.getLogger(SaslBootstrap.class); - private final String secretKeyId; + private final TransportConf conf; + private final String appId; private final SecretKeyHolder secretKeyHolder; - public SaslBootstrap(String secretKeyId, SecretKeyHolder secretKeyHolder) { - this.secretKeyId = secretKeyId; + public SaslBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.appId = appId; this.secretKeyHolder = secretKeyHolder; } + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. 
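The new saslRTTimeout() bound applies to each challenge-response RPC individually, replacing the hard-coded 300000 ms value in the bootstrap loop. With the SystemPropertyConfigProvider used in these tests it can be tuned as a system property; a sketch:

    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class SaslTimeoutSketch {
      public static void main(String[] args) {
        // Default is 30000 ms per SASL token round trip.
        System.setProperty("spark.shuffle.sasl.timeout", "10000");
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        System.out.println(conf.saslRTTimeout());  // 10000
      }
    }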
+ */ public void doBootstrap(TransportClient client) { - SparkSaslClient saslClient = new SparkSaslClient(secretKeyId, secretKeyHolder); + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); try { byte[] payload = saslClient.firstToken(); while (!saslClient.isComplete()) { - SaslMessage msg = new SaslMessage(secretKeyId, payload); - logger.info("Sending msg {} {}", secretKeyId, payload.length); + SaslMessage msg = new SaslMessage(appId, payload); + logger.info("Sending msg {} {}", appId, payload.length); ByteBuf buf = Unpooled.buffer(msg.encodedLength()); msg.encode(buf); - byte[] response = client.sendRpcSync(buf.array(), 300000); - logger.info("Got response {} {}", secretKeyId, response.length); + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout()); + logger.info("Got response {} {}", appId, response.length); payload = saslClient.response(response); } } finally { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 73dd8e264e580..695429a2232f0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -96,7 +96,7 @@ public void afterEach() { public void testGoodClient() { clientFactory = context.createClientFactory( Lists.newArrayList( - new SaslBootstrap("app-id", new TestSecretKeyHolder("good-key")))); + new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; @@ -108,7 +108,7 @@ public void testGoodClient() { public void testBadClient() { clientFactory = context.createClientFactory( Lists.newArrayList( - new SaslBootstrap("app-id", new TestSecretKeyHolder("bad-key")))); + new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); try { // Bootstrap should fail on startup. @@ -146,7 +146,7 @@ public void testNoSaslServer() { TransportContext context = new TransportContext(conf, handler); clientFactory = context.createClientFactory( Lists.newArrayList( - new SaslBootstrap("app-id", new TestSecretKeyHolder("key")))); + new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); TransportServer server = context.createServer(); try { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); From 79973cb1318812488f183c5a049c4018253937fc Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 3 Nov 2014 22:31:24 -0800 Subject: [PATCH 05/10] Remove unused file --- .../network/sasl/TestSecretKeyHolder.java | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/TestSecretKeyHolder.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/TestSecretKeyHolder.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/TestSecretKeyHolder.java deleted file mode 100644 index e3a69c25cf60d..0000000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/TestSecretKeyHolder.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; From 785bbde163bd6c4656192727f8b539c21d9eb83f Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 3 Nov 2014 23:45:56 -0800 Subject: [PATCH 06/10] Cleanup --- .../java/org/apache/spark/network/sasl/SaslBootstrap.java | 2 -- .../java/org/apache/spark/network/sasl/SaslMessage.java | 4 ++-- .../java/org/apache/spark/network/sasl/SaslRpcHandler.java | 6 ++++-- .../java/org/apache/spark/network/sasl/SecretKeyHolder.java | 2 +- .../java/org/apache/spark/network/sasl/SparkSaslServer.java | 1 - 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java index 8dea863e6f4e4..d8566749780a8 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java @@ -55,12 +55,10 @@ public void doBootstrap(TransportClient client) { while (!saslClient.isComplete()) { SaslMessage msg = new SaslMessage(appId, payload); - logger.info("Sending msg {} {}", appId, payload.length); ByteBuf buf = Unpooled.buffer(msg.encodedLength()); msg.encode(buf); byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout()); - logger.info("Got response {} {}", appId, response.length); payload = saslClient.response(response); } } finally { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 2a28ae0f63e58..5b77e18c26bf4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -25,7 +25,7 @@ /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged * with the given appId. This appId allows a single SaslRpcHandler to multiplex different - * applications who may be using different sets of credentials. + * applications which may be using different sets of credentials. 
*/ class SaslMessage implements Encodable { @@ -51,7 +51,7 @@ public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); byte[] idBytes = appId.getBytes(Charsets.UTF_8); buf.writeInt(idBytes.length); - buf.writeBytes(appId.getBytes(Charsets.UTF_8)); + buf.writeBytes(idBytes); buf.writeInt(payload.length); buf.writeBytes(payload); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index f03ba31b121aa..6701c91e4f99b 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -43,10 +43,13 @@ public class SaslRpcHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + /** RpcHandler we will delegate for authenticated connections. */ private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ private final SecretKeyHolder secretKeyHolder; - // TODO: Invalidate channels that have closed! + /** Maps each channel to its SASL authentication state. */ private final ConcurrentMap channelAuthenticationMap; public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { @@ -92,4 +95,3 @@ public void connectionTerminated(TransportClient client) { } } } - diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java index aaf50d2b62dd2..81d5766794688 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java @@ -18,7 +18,7 @@ package org.apache.spark.network.sasl; /** - * Interface for getting a secret key associated with some appId. + * Interface for getting a secret key associated with some application. 
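A SecretKeyHolder keyed on appId is what lets one SaslRpcHandler serve several applications, as the SaslMessage javadoc above describes. A map-backed sketch; the class and the fixed SASL user name are illustrative, not part of the patch.

    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;

    import org.apache.spark.network.sasl.SecretKeyHolder;

    public class MultiAppSecretKeyHolder implements SecretKeyHolder {
      private final ConcurrentMap<String, String> secrets =
          new ConcurrentHashMap<String, String>();

      /** Called when an application registers with the service (illustrative). */
      public void registerApp(String appId, String secretKey) {
        secrets.put(appId, secretKey);
      }

      @Override
      public String getSaslUser(String appId) {
        return "sparkSaslUser";  // fixed user name; only the secret is per-app
      }

      @Override
      public String getSecretKey(String appId) {
        String key = secrets.get(appId);
        if (key == null) {
          throw new IllegalArgumentException("Unknown appId: " + appId);
        }
        return key;
      }
    }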
*/ public interface SecretKeyHolder { /** diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 937370355dfa8..d8b164c5ee637 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -80,7 +80,6 @@ public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, new DigestCallbackHandler()); } catch (SaslException e) { - System.exit(0); throw Throwables.propagate(e); } } From a6b95f1c5f12040f63b2b6387c6fab5103cd31ce Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 4 Nov 2014 09:05:21 -0800 Subject: [PATCH 07/10] Address comments --- .../network/netty/NettyBlockTransferService.scala | 6 ++---- .../network/client/TransportClientFactory.java | 14 ++++++++++---- ...SaslBootstrap.java => SaslClientBootstrap.java} | 7 ++++--- .../apache/spark/network/sasl/SaslRpcHandler.java | 2 +- .../spark/network/sasl/SaslIntegrationSuite.java | 6 +++--- 5 files changed, 20 insertions(+), 15 deletions(-) rename network/shuffle/src/main/java/org/apache/spark/network/sasl/{SaslBootstrap.java => SaslClientBootstrap.java} (90%) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 017666965ddc5..0d1fc81d2a16f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,8 +17,6 @@ package org.apache.spark.network.netty -import org.apache.spark.network.util.TransportConf - import scala.collection.JavaConversions._ import scala.concurrent.{Future, Promise} @@ -27,7 +25,7 @@ import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} -import org.apache.spark.network.sasl.{SecretKeyHolder, SaslRpcHandler, SaslBootstrap} +import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer @@ -56,7 +54,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage (nettyRpcHandler, None) } else { (new SaslRpcHandler(nettyRpcHandler, securityManager), - Some(new SaslBootstrap(transportConf, conf.getAppId, securityManager))) + Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) } } transportContext = new TransportContext(transportConf, rpcHandler) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 8aa860dbd6d92..1723fed307257 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -130,6 +130,7 @@ public void initChannel(SocketChannel ch) { }); // Connect to the remote server + long 
preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new RuntimeException( @@ -142,25 +143,30 @@ public void initChannel(SocketChannel ch) { assert client != null : "Channel future completed successfully with null client"; // Execute any client bootstraps synchronously before marking the Client as successful. + long preBootstrap = System.currentTimeMillis(); logger.debug("Connection to {} successful, running bootstraps...", address); try { for (TransportClientBootstrap clientBootstrap : clientBootstraps) { clientBootstrap.doBootstrap(client); } - } catch (Exception e) { // catch Exception as the bootstrap may be written in Scala - logger.error("Exception while bootstrapping client", e); + } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala + long bootstrapTime = System.currentTimeMillis() - preBootstrap; + logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); client.close(); throw Throwables.propagate(e); } + long postBootstrap = System.currentTimeMillis(); - logger.debug("Successfully executed {} bootstraps for {}", clientBootstraps.size(), address); // Successful connection & bootstrap -- in the event that two threads raced to create a client, // use the first one that was put into the connectionPool and close the one we made here. TransportClient oldClient = connectionPool.putIfAbsent(address, client); if (oldClient == null) { + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, postBootstrap - preConnect, postBootstrap - preBootstrap); return client; } else { - logger.debug("Two clients were created concurrently, second one will be disposed."); + logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", + postBootstrap - preConnect); client.close(); return oldClient; } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java similarity index 90% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java rename to network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index d8566749780a8..7bc91e375371f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -30,14 +30,14 @@ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. 
*/ -public class SaslBootstrap implements TransportClientBootstrap { - private final Logger logger = LoggerFactory.getLogger(SaslBootstrap.class); +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); private final TransportConf conf; private final String appId; private final SecretKeyHolder secretKeyHolder; - public SaslBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { this.conf = conf; this.appId = appId; this.secretKeyHolder = secretKeyHolder; @@ -48,6 +48,7 @@ public SaslBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKey * challenge-response tokens until we either successfully authenticate or throw an exception * due to mismatch. */ + @Override public void doBootstrap(TransportClient client) { SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); try { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 6701c91e4f99b..3777a18e33f78 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -43,7 +43,7 @@ public class SaslRpcHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); - /** RpcHandler we will delegate for authenticated connections. */ + /** RpcHandler we will delegate to for authenticated connections. */ private final RpcHandler delegate; /** Class which provides secret keys which are shared by server and client on a per-app basis. */ diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 695429a2232f0..84781207861ed 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -96,7 +96,7 @@ public void afterEach() { public void testGoodClient() { clientFactory = context.createClientFactory( Lists.newArrayList( - new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; @@ -108,7 +108,7 @@ public void testGoodClient() { public void testBadClient() { clientFactory = context.createClientFactory( Lists.newArrayList( - new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); try { // Bootstrap should fail on startup. 
@@ -146,7 +146,7 @@ public void testNoSaslServer() {
     TransportContext context = new TransportContext(conf, handler);
     clientFactory = context.createClientFactory(
       Lists.newArrayList(
-        new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
     TransportServer server = context.createServer();
     try {
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
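The TransportClientFactory change in this patch also pins down the bootstrap contract: every TransportClientBootstrap runs synchronously inside createClient(), after the TCP connection succeeds but before the client is published to the connection pool, and any exception closes the half-initialized client and fails the createClient() call. A minimal sketch of a custom bootstrap under that contract (the class and its logging are hypothetical; only the interface and the factory wiring mirror the code above):

    // Hypothetical no-op bootstrap; SaslClientBootstrap is the real instance of
    // this interface in the patch.
    public class LoggingBootstrap implements TransportClientBootstrap {
      private final Logger logger = LoggerFactory.getLogger(LoggingBootstrap.class);

      @Override
      public void doBootstrap(TransportClient client) {
        // Runs before createClient() returns and before the client is cached in
        // the connection pool; throwing here closes the client and propagates
        // the failure to the caller.
        logger.debug("Bootstrapping {}", client);
      }
    }

It would be wired exactly like the suites above: clientFactory = context.createClientFactory(Lists.newArrayList(new LoggingBootstrap())).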
From eb9f0659109aba422021bdec6770b4beb0f74a59 Mon Sep 17 00:00:00 2001
From: Aaron Davidson
Date: Tue, 4 Nov 2014 10:18:14 -0800
Subject: [PATCH 08/10] Improve documentation and add end-to-end test at Spark-level

---
 .../org/apache/spark/SecurityManager.scala    |   9 +-
 .../NettyBlockTransferSecuritySuite.scala     | 161 ++++++++++++++++++
 .../spark/network/sasl/SparkSaslServer.java   |   9 +-
 3 files changed, 172 insertions(+), 7 deletions(-)
 create mode 100644 core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala

diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index a77e8e788d730..7c8bda125c5c7 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -85,7 +85,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder
 *            Authenticator installed in the SecurityManager to how it does the authentication
 *            and in this case gets the user name and password from the request.
 *
- *  - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ *  - BlockTransferService -> The Spark BlockTransferServices use java nio to asynchronously
 *            exchange messages.  For this we use the Java SASL
 *            (Simple Authentication and Security Layer) API and again use DIGEST-MD5
 *            as the authentication mechanism. This means the shared secret is not passed
@@ -99,7 +99,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder
 *            of protection they want. If we support those, the messages will also have to
 *            be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
 *
- *            Since the connectionManager does asynchronous messages passing, the SASL
+ *            Since the NioBlockTransferService does asynchronous message passing, the SASL
 *            authentication is a bit more complex. A ConnectionManager can be both a client
 *            and a Server, so for a particular connection is has to determine what to do.
 *            A ConnectionId was added to be able to track connections and is used to
@@ -108,6 +108,10 @@ import org.apache.spark.network.sasl.SecretKeyHolder
 *            and waits for the response from the server and does the handshake before sending
 *            the real message.
 *
+ *            The NettyBlockTransferService ensures that SASL authentication is performed
+ *            synchronously prior to any other communication on a connection. This is done in
+ *            SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
 *  - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
 *            can be used. Yarn requires a specific AmIpFilter be installed for security to work
 *            properly. For non-Yarn deployments, users can write a filter to go through a
@@ -347,6 +351,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with
 
   override def getSecretKey(appId: String): String = {
     val myAppId = sparkConf.getAppId
+    println("App id: " + appId + " / " + myAppId)
     require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
     getSecretKey()
   }
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
new file mode 100644
index 0000000000000..bed0ed9d713dd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio._
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.util.{Failure, Success, Try}
+
+import org.apache.commons.io.IOUtils
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.{BlockDataManager, BlockTransferService}
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}
+
+class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
+  test("security default off") {
+    testConnection(new SparkConf, new SparkConf) match {
+      case Success(_) => // expected
+      case Failure(t) => fail(t)
+    }
+  }
+
+  test("security on same password") {
+    val conf = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    testConnection(conf, conf) match {
+      case Success(_) => // expected
+      case Failure(t) => fail(t)
+    }
+  }
+
+  test("security on mismatch password") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.authenticate.secret", "bad")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => t.getMessage should include ("Mismatched response")
+    }
+  }
+
+  test("security mismatch auth off on server") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.authenticate", "false")
"false") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => // any funny error may occur, sever will interpret SASL token as RPC + } + } + + test("security mismatch auth off on client") { + val conf0 = new SparkConf() + .set("spark.authenticate", "false") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "true") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Expected SaslMessage") + } + } + + test("security mismatch app ids") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.app.id", "other-id") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("SASL appId app-id did not match") + } + } + + /** + * Creates two servers with different configurations and sees if they can talk. + * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed + * properly. We will throw an out-of-band exception if something other than that goes wrong. + */ + private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = { + val blockManager = mock[BlockDataManager] + val blockId = ShuffleBlockId(0, 1, 2) + val blockString = "Hello, world!" + val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes)) + when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) + + val securityManager0 = new SecurityManager(conf0) + val exec0 = new NettyBlockTransferService(conf0, securityManager0) + exec0.init(blockManager) + + val securityManager1 = new SecurityManager(conf1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1) + exec1.init(blockManager) + + val result = fetchBlock(exec0, exec1, "1", blockId) match { + case Success(buf) => + IOUtils.toString(buf.createInputStream()) should equal(blockString) + buf.release() + Success() + case Failure(t) => + Failure(t) + } + exec0.close() + exec1.close() + result + } + + /** Synchronously fetches a single block, acting as the given executor fetching from another. 
+  private def fetchBlock(
+      self: BlockTransferService,
+      from: BlockTransferService,
+      execId: String,
+      blockId: BlockId): Try[ManagedBuffer] = {
+
+    val promise = Promise[ManagedBuffer]()
+
+    self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
+      new BlockFetchingListener {
+        override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+          promise.failure(exception)
+        }
+
+        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+          promise.success(data.retain())
+        }
+      })
+
+    Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS))
+    promise.future.value.get
+  }
+}
+
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
index d8b164c5ee637..2c0ce40c75e80 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -31,6 +31,7 @@
 import java.util.Map;
 
 import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.io.BaseEncoding;
@@ -157,15 +158,13 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
 
   /* Encode a byte[] identifier as a Base64-encoded string. */
   public static String encodeIdentifier(String identifier) {
+    Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
     return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8));
   }
 
   /** Encode a password as a base64-encoded char[] array. */
   public static char[] encodePassword(String password) {
-    if (password != null) {
-      return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray();
-    } else {
-      return new char[0];
-    }
+    Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
+    return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray();
   }
 }
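The SparkSaslServer change above trades null-tolerance for fail-fast checks: a missing SASL user or secret now fails at handshake setup with a pointed message, rather than producing an empty Base64 token that would only fail later, deep inside the DIGEST-MD5 exchange. A quick sketch of the resulting behavior (values are illustrative):

    String user = SparkSaslServer.encodeIdentifier("sparkSaslUser");   // Base64 of the UTF-8 bytes
    char[] password = SparkSaslServer.encodePassword("shared-secret"); // Base64, as a char[]
    SparkSaslServer.encodePassword(null);
    // -> NullPointerException: "Password cannot be null if SASL is enabled"
    //    (previously this returned new char[0] and failed later, during authentication)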
From 44f8410b5f9a855e27d1bc91c4b605bfcff7641a Mon Sep 17 00:00:00 2001
From: Aaron Davidson
Date: Tue, 4 Nov 2014 10:19:51 -0800
Subject: [PATCH 09/10] Delete documentation - muahaha!

---
 docs/security.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docs/security.md b/docs/security.md
index ec0523184d665..1e206a139fb72 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can
 
 * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret.
 * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications.
-* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.*
 
 ## Web UI

From 3481718baf66f4c02568f08d7011af19a2c47f4b Mon Sep 17 00:00:00 2001
From: Aaron Davidson
Date: Tue, 4 Nov 2014 10:20:32 -0800
Subject: [PATCH 10/10] Delete rogue println

---
 core/src/main/scala/org/apache/spark/SecurityManager.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 7c8bda125c5c7..dee935ffad51f 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -351,7 +351,6 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with
 
   override def getSecretKey(appId: String): String = {
     val myAppId = sparkConf.getAppId
-    println("App id: " + appId + " / " + myAppId)
     require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
     getSecretKey()
   }
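Taken together, every SASL endpoint in this series resolves credentials through SecretKeyHolder, keyed by appId: SecurityManager's implementation require()s that the requested appId match spark.app.id before returning the secret, while the integration suite uses a permissive TestSecretKeyHolder. The body of that test holder is not shown in these patches, so the following is an assumed reconstruction, consistent with its single-argument usage in SaslIntegrationSuite:

    // Assumed reconstruction of the suite's TestSecretKeyHolder: one fixed key,
    // handed out for any appId. SecurityManager, by contrast, rejects foreign
    // appIds with "SASL appId ... did not match my appId ...".
    class TestSecretKeyHolder implements SecretKeyHolder {
      private final String secretKey;

      TestSecretKeyHolder(String secretKey) {
        this.secretKey = secretKey;
      }

      @Override
      public String getSaslUser(String appId) { return "user"; }

      @Override
      public String getSecretKey(String appId) { return secretKey; }
    }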