Skip to content

Commit

Permalink
[SPARK-11098][CORE] Add Outbox to cache the sending messages to resol…
Browse files Browse the repository at this point in the history
…ve the message disorder issue

The current NettyRpc has a message order issue because it uses a thread pool to send messages. E.g., running the following two lines in the same thread,

```
ref.send("A")
ref.send("B")
```

The remote endpoint may see "B" before "A" because sending "A" and "B" are in parallel.
To resolve this issue, this PR added an outbox for each connection, and if we are connecting to the remote node when sending messages, just cache the sending messages in the outbox and send them one by one when the connection is established.

Author: zsxwing <zsxwing@gmail.com>

Closes #9197 from zsxwing/rpc-outbox.
  • Loading branch information
zsxwing authored and rxin committed Oct 23, 2015
1 parent 34e71c6 commit a88c66c
Show file tree
Hide file tree
Showing 2 changed files with 310 additions and 57 deletions.
145 changes: 88 additions & 57 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.io._
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
Expand Down Expand Up @@ -70,12 +71,30 @@ private[netty] class NettyRpcEnv(
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
// to implement non-blocking send/ask.
// TODO: a non-blocking TransportClientFactory.createClient in future
private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
"netty-rpc-connection",
conf.getInt("spark.rpc.connect.threads", 64))

@volatile private var server: TransportServer = _

private val stopped = new AtomicBoolean(false)

/**
* A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
* we just put messages to its [[Outbox]] to implement a non-blocking `send` method.
*/
private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()

/**
* Remove the address's Outbox and stop it.
*/
private[netty] def removeOutbox(address: RpcAddress): Unit = {
val outbox = outboxes.remove(address)
if (outbox != null) {
outbox.stop()
}
}

def start(port: Int): Unit = {
val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
Expand Down Expand Up @@ -116,6 +135,30 @@ private[netty] class NettyRpcEnv(
dispatcher.stop(endpointRef)
}

private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = {
val targetOutbox = {
val outbox = outboxes.get(address)
if (outbox == null) {
val newOutbox = new Outbox(this, address)
val oldOutbox = outboxes.putIfAbsent(address, newOutbox)
if (oldOutbox == null) {
newOutbox
} else {
oldOutbox
}
} else {
outbox
}
}
if (stopped.get) {
// It's possible that we put `targetOutbox` after stopping. So we need to clean it.
outboxes.remove(address)
targetOutbox.stop()
} else {
targetOutbox.send(message)
}
}

private[netty] def send(message: RequestMessage): Unit = {
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
Expand All @@ -127,37 +170,28 @@ private[netty] class NettyRpcEnv(
val ack = response.asInstanceOf[Ack]
logTrace(s"Received ack from ${ack.sender}")
case Failure(e) =>
logError(s"Exception when sending $message", e)
logWarning(s"Exception when sending $message", e)
}(ThreadUtils.sameThread)
} else {
// Message to a remote RPC endpoint.
try {
// `createClient` will block if it cannot find a known connection, so we should run it in
// clientConnectionExecutor
clientConnectionExecutor.execute(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
client.sendRpc(serialize(message), new RpcResponseCallback {

override def onFailure(e: Throwable): Unit = {
logError(s"Exception when sending $message", e)
}

override def onSuccess(response: Array[Byte]): Unit = {
val ack = deserialize[Ack](response)
logDebug(s"Receive ack from ${ack.sender}")
}
})
}
})
} catch {
case e: RejectedExecutionException =>
// `send` after shutting clientConnectionExecutor down, ignore it
logWarning(s"Cannot send $message because RpcEnv is stopped")
}
postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {

override def onFailure(e: Throwable): Unit = {
logWarning(s"Exception when sending $message", e)
}

override def onSuccess(response: Array[Byte]): Unit = {
val ack = deserialize[Ack](response)
logDebug(s"Receive ack from ${ack.sender}")
}
}))
}
}

private[netty] def createClient(address: RpcAddress): TransportClient = {
clientFactory.createClient(address.host, address.port)
}

private[netty] def ask(message: RequestMessage): Future[Any] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
Expand All @@ -180,39 +214,25 @@ private[netty] class NettyRpcEnv(
}
}(ThreadUtils.sameThread)
} else {
try {
// `createClient` will block if it cannot find a known connection, so we should run it in
// clientConnectionExecutor
clientConnectionExecutor.execute(new Runnable {
override def run(): Unit = {
val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
client.sendRpc(serialize(message), new RpcResponseCallback {

override def onFailure(e: Throwable): Unit = {
if (!promise.tryFailure(e)) {
logWarning("Ignore Exception", e)
}
}

override def onSuccess(response: Array[Byte]): Unit = {
val reply = deserialize[AskResponse](response)
if (reply.reply.isInstanceOf[RpcFailure]) {
if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
logWarning(s"Ignore failure: ${reply.reply}")
}
} else if (!promise.trySuccess(reply.reply)) {
logWarning(s"Ignore message: ${reply}")
}
}
})
}
})
} catch {
case e: RejectedExecutionException =>
postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {

override def onFailure(e: Throwable): Unit = {
if (!promise.tryFailure(e)) {
logWarning(s"Ignore failure", e)
logWarning("Ignore Exception", e)
}
}
}

override def onSuccess(response: Array[Byte]): Unit = {
val reply = deserialize[AskResponse](response)
if (reply.reply.isInstanceOf[RpcFailure]) {
if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
logWarning(s"Ignore failure: ${reply.reply}")
}
} else if (!promise.trySuccess(reply.reply)) {
logWarning(s"Ignore message: ${reply}")
}
}
}))
}
promise.future
}
Expand Down Expand Up @@ -245,6 +265,16 @@ private[netty] class NettyRpcEnv(
}

private def cleanup(): Unit = {
if (!stopped.compareAndSet(false, true)) {
return
}

val iter = outboxes.values().iterator()
while (iter.hasNext()) {
val outbox = iter.next()
outboxes.remove(outbox.address)
outbox.stop()
}
if (timeoutScheduler != null) {
timeoutScheduler.shutdownNow()
}
Expand Down Expand Up @@ -463,6 +493,7 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
val messageOpt: Option[RemoteProcessDisconnected] =
synchronized {
remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
Expand Down

0 comments on commit a88c66c

Please sign in to comment.