diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index d6acd09a419..bb6b96e93a9 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -400,6 +400,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se new RpcTimeout(get(RPC_LOOKUP_TIMEOUT).milli, RPC_LOOKUP_TIMEOUT.key) def rpcAskTimeout: RpcTimeout = new RpcTimeout(get(RPC_ASK_TIMEOUT).milli, RPC_ASK_TIMEOUT.key) + def rpcInMemoryBoundedInboxCapacity(): Int = { + get(RPC_INBOX_CAPACITY) + } def rpcDispatcherNumThreads(availableCores: Int): Int = { val num = get(RPC_DISPATCHER_THREADS) if (num != 0) num else availableCores @@ -1592,6 +1595,17 @@ object CelebornConf extends Logging { .intConf .createWithDefault(0) + val RPC_INBOX_CAPACITY: ConfigEntry[Int] = + buildConf("celeborn.rpc.inbox.capacity") + .categories("network") + .doc("Maximum number of in-memory messages allowed in each RPC inbox. 0 means unbounded.") + .version("0.5.0") + .intConf + .checkValue( + v => v >= 0, + "the capacity of inbox must be no less than 0, 0 means no limitation") + .createWithDefault(0) + val RPC_ROLE_DISPATHER_THREADS: ConfigEntry[Int] = buildConf("celeborn..rpc.dispatcher.threads") .categories("network") diff --git a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala index 391b6418640..b8a93bda0ce 100644 --- a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala +++ b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Dispatcher.scala @@ -39,7 +39,8 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val name: String, val endpoint: RpcEndpoint, val ref: NettyRpcEndpointRef) { - val inbox = new Inbox(ref, endpoint) + val celebornConf = nettyEnv.celebornConf + val inbox = 
new Inbox(ref, endpoint, celebornConf) } private val endpoints: ConcurrentMap[String, EndpointData] = @@ -157,7 +158,14 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { endpointName: String, message: InboxMessage, callbackIfStopped: Exception => Unit): Unit = { + val data = synchronized { + endpoints.get(endpointName) + } + if (data != null) { + data.inbox.waitOnFull() + } val error = synchronized { + // double check val data = endpoints.get(endpointName) if (stopped) { Some(new RpcEnvStoppedException()) diff --git a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala index 09cdd08e256..75025245652 100644 --- a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala +++ b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala @@ -17,10 +17,13 @@ package org.apache.celeborn.common.rpc.netty +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.locks.ReentrantLock import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal +import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.CelebornException import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} @@ -64,14 +67,21 @@ private[celeborn] case class RemoteProcessConnectionError( */ private[celeborn] class Inbox( val endpointRef: NettyRpcEndpointRef, - val endpoint: RpcEndpoint) - extends Logging { + val endpoint: RpcEndpoint, + val conf: CelebornConf) extends Logging { inbox => // Give this an alias so we can use it more clearly in closures. 
+ private[netty] val capacity = conf.get(CelebornConf.RPC_INBOX_CAPACITY) + + private[netty] val inboxLock = new ReentrantLock() + private[netty] val isFull = inboxLock.newCondition() + @GuardedBy("this") protected val messages = new java.util.LinkedList[InboxMessage]() + private val messageCount = new AtomicLong(0) + /** True if the inbox (and its associated endpoint) is stopped. */ @GuardedBy("this") private var stopped = false @@ -85,84 +95,130 @@ private var numActiveThreads = 0 // OnStart should be the first message to process - inbox.synchronized { + inboxLock.lockInterruptibly() + try { messages.add(OnStart) + messageCount.incrementAndGet() + } finally { + inboxLock.unlock() + } + + def addMessage(message: InboxMessage): Unit = { + messages.add(message) + messageCount.incrementAndGet() + signalNotFull() + logDebug(s"queue length of ${messageCount.get()} ") + } + + private def processInternal(dispatcher: Dispatcher, message: InboxMessage): Unit = { + message match { + case RpcMessage(_sender, content, context) => + try { + endpoint.receiveAndReply(context).applyOrElse[Any, Unit]( + content, + { msg => + throw new CelebornException(s"Unsupported message $message from ${_sender}") + }) + } catch { + case e: Throwable => + context.sendFailure(e) + // Throw the exception -- this exception will be caught by the safelyCall function. + // The endpoint's onError function will be called. 
+ throw e + } + + case OneWayMessage(_sender, content) => + endpoint.receive.applyOrElse[Any, Unit]( + content, + { msg => + throw new CelebornException(s"Unsupported message $message from ${_sender}") + }) + + case OnStart => + endpoint.onStart() + if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + inboxLock.lockInterruptibly() + try { + if (!stopped) { + enableConcurrent = true + } + } finally { + inboxLock.unlock() + } + } + + case OnStop => + inboxLock.lockInterruptibly() + val activeThreads = + try { + inbox.numActiveThreads + } finally { + inboxLock.unlock() + } + assert( + activeThreads == 1, + s"There should be only a single active thread but found $activeThreads threads.") + dispatcher.removeRpcEndpointRef(endpoint) + endpoint.onStop() + assert(isEmpty, "OnStop should be the last message") + + case RemoteProcessConnected(remoteAddress) => + endpoint.onConnected(remoteAddress) + + case RemoteProcessDisconnected(remoteAddress) => + endpoint.onDisconnected(remoteAddress) + + case RemoteProcessConnectionError(cause, remoteAddress) => + endpoint.onNetworkError(cause, remoteAddress) + } + } + + private[netty] def waitOnFull(): Unit = { + if (capacity > 0 && !stopped) { + inboxLock.lockInterruptibly() + try { + while (messageCount.get() >= capacity) { + isFull.await() + } + } finally { + inboxLock.unlock() + } + } + } + + private def signalNotFull(): Unit = { + // when this is called we assume putLock already being called + require(inboxLock.isHeldByCurrentThread, "cannot call signalNotFull without holding lock") + if (capacity > 0 && messageCount.get() < capacity) { + isFull.signal() + } } - /** - * Process stored messages. 
- */ def process(dispatcher: Dispatcher): Unit = { var message: InboxMessage = null - inbox.synchronized { + inboxLock.lockInterruptibly() + try { if (!enableConcurrent && numActiveThreads != 0) { return } message = messages.poll() if (message != null) { numActiveThreads += 1 + messageCount.decrementAndGet() + signalNotFull() } else { return } + } finally { + inboxLock.unlock() } + while (true) { safelyCall(endpoint, endpointRef.name) { - message match { - case RpcMessage(_sender, content, context) => - try { - endpoint.receiveAndReply(context).applyOrElse[Any, Unit]( - content, - { msg => - throw new CelebornException(s"Unsupported message $message from ${_sender}") - }) - } catch { - case e: Throwable => - context.sendFailure(e) - // Throw the exception -- this exception will be caught by the safelyCall function. - // The endpoint's onError function will be called. - throw e - } - - case OneWayMessage(_sender, content) => - endpoint.receive.applyOrElse[Any, Unit]( - content, - { msg => - throw new CelebornException(s"Unsupported message $message from ${_sender}") - }) - - case OnStart => - endpoint.onStart() - if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { - inbox.synchronized { - if (!stopped) { - enableConcurrent = true - } - } - } - - case OnStop => - val activeThreads = inbox.synchronized { - inbox.numActiveThreads - } - assert( - activeThreads == 1, - s"There should be only a single active thread but found $activeThreads threads.") - dispatcher.removeRpcEndpointRef(endpoint) - endpoint.onStop() - assert(isEmpty, "OnStop should be the last message") - - case RemoteProcessConnected(remoteAddress) => - endpoint.onConnected(remoteAddress) - - case RemoteProcessDisconnected(remoteAddress) => - endpoint.onDisconnected(remoteAddress) - - case RemoteProcessConnectionError(cause, remoteAddress) => - endpoint.onNetworkError(cause, remoteAddress) - } + processInternal(dispatcher, message) } - - inbox.synchronized { + inboxLock.lockInterruptibly() + try { // 
"enableConcurrent" will be set to false after `onStop` is called, so we should check it // every time. if (!enableConcurrent && numActiveThreads != 1) { @@ -174,37 +230,56 @@ private[celeborn] class Inbox( if (message == null) { numActiveThreads -= 1 return + } else { + messageCount.decrementAndGet() + signalNotFull() } + } finally { + inboxLock.unlock() } } } - def post(message: InboxMessage): Unit = inbox.synchronized { - if (stopped) { - // We already put "OnStop" into "messages", so we should drop further messages - onDrop(message) - } else { - messages.add(message) - false + def post(message: InboxMessage): Unit = { + inboxLock.lockInterruptibly() + try { + if (stopped) { + // We already put "OnStop" into "messages", so we should drop further messages + onDrop(message) + } else { + addMessage(message) + } + } finally { + inboxLock.unlock() } } - def stop(): Unit = inbox.synchronized { - // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last - // message - if (!stopped) { - // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only - // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources - // safely. - enableConcurrent = false - stopped = true - messages.add(OnStop) - // Note: The concurrent events in messages will be processed one by one. + def stop(): Unit = { + inboxLock.lockInterruptibly() + try { + // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last + // message + if (!stopped) { + // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only + // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources + // safely. + enableConcurrent = false + stopped = true + addMessage(OnStop) + // Note: The concurrent events in messages will be processed one by one. 
+ } + } finally { + inboxLock.unlock() } } - def isEmpty: Boolean = inbox.synchronized { - messages.isEmpty + def isEmpty: Boolean = { + inboxLock.lockInterruptibly() + try { + messages.isEmpty + } finally { + inboxLock.unlock() + } } /** @@ -222,10 +297,13 @@ private[celeborn] class Inbox( endpoint: RpcEndpoint, endpointRefName: String)(action: => Unit): Unit = { def dealWithFatalError(fatal: Throwable): Unit = { - inbox.synchronized { + inboxLock.lockInterruptibly() + try { assert(numActiveThreads > 0, "The number of active threads should be positive.") // Should reduce the number of active threads before throw the error. numActiveThreads -= 1 + } finally { + inboxLock.unlock() } logError( s"An error happened while processing message in the inbox for $endpointRefName", @@ -254,8 +332,11 @@ private[celeborn] class Inbox( // exposed only for testing def getNumActiveThreads: Int = { - inbox.synchronized { + inboxLock.lockInterruptibly() + try { inbox.numActiveThreads + } finally { + inboxLock.unlock() } } } diff --git a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala index a8bc826dd4b..ab86a57e8a1 100644 --- a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala @@ -21,18 +21,40 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfter import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.rpc.{RpcAddress, TestRpcEndpoint} -class InboxSuite extends CelebornFunSuite { +class InboxSuite extends CelebornFunSuite with BeforeAndAfter { - test("post") { - val endpoint = new TestRpcEndpoint + private var inbox: Inbox = _ + private var endpoint: TestRpcEndpoint = _ + + def 
initInbox[T]( + testRpcEndpoint: TestRpcEndpoint, + onDropOverride: Option[InboxMessage => T]): Inbox = { val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) + if (onDropOverride.isEmpty) { + new Inbox(rpcEnvRef, testRpcEndpoint, new CelebornConf()) + } else { + new Inbox(rpcEnvRef, testRpcEndpoint, new CelebornConf()) { + override protected def onDrop(message: InboxMessage): Unit = { + onDropOverride.get(message) + } + } + } + } + + before { + endpoint = new TestRpcEndpoint + inbox = initInbox(endpoint, None) + } + + test("post") { val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(rpcEnvRef, endpoint) val message = OneWayMessage(null, "hi") inbox.post(message) inbox.process(dispatcher) @@ -48,11 +70,8 @@ class InboxSuite extends CelebornFunSuite { } test("post: with reply") { - val endpoint = new TestRpcEndpoint - val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(rpcEnvRef, endpoint) val message = RpcMessage(null, "hi", null) inbox.post(message) inbox.process(dispatcher) @@ -62,16 +81,15 @@ class InboxSuite extends CelebornFunSuite { } test("post: multiple threads") { - val endpoint = new TestRpcEndpoint val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) val numDroppedMessages = new AtomicInteger(0) - val inbox = new Inbox(rpcEnvRef, endpoint) { - override def onDrop(message: InboxMessage): Unit = { - numDroppedMessages.incrementAndGet() - } + + val overrideOnDrop = (msg: InboxMessage) => { + numDroppedMessages.incrementAndGet() } + val inbox = initInbox(endpoint, Some(overrideOnDrop)) val exitLatch = new CountDownLatch(10) @@ -102,12 +120,9 @@ class InboxSuite extends CelebornFunSuite { } test("post: Associated") { - val endpoint = new TestRpcEndpoint - val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new Inbox(rpcEnvRef, 
endpoint) inbox.post(RemoteProcessConnected(remoteAddress)) inbox.process(dispatcher) @@ -115,13 +130,10 @@ class InboxSuite extends CelebornFunSuite { } test("post: Disassociated") { - val endpoint = new TestRpcEndpoint - val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new Inbox(rpcEnvRef, endpoint) inbox.post(RemoteProcessDisconnected(remoteAddress)) inbox.process(dispatcher) @@ -129,14 +141,11 @@ class InboxSuite extends CelebornFunSuite { } test("post: AssociationError") { - val endpoint = new TestRpcEndpoint - val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) val remoteAddress = RpcAddress("localhost", 11111) val cause = new RuntimeException("Oops") - val inbox = new Inbox(rpcEnvRef, endpoint) inbox.post(RemoteProcessConnectionError(cause, remoteAddress)) inbox.process(dispatcher) @@ -146,9 +155,8 @@ class InboxSuite extends CelebornFunSuite { test("should reduce the number of active threads when fatal error happens") { val endpoint = mock(classOf[TestRpcEndpoint]) when(endpoint.receive).thenThrow(new OutOfMemoryError()) - val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(rpcEnvRef, endpoint) + val inbox = initInbox(endpoint, None) inbox.post(OneWayMessage(null, "hi")) intercept[OutOfMemoryError] { inbox.process(dispatcher) diff --git a/docs/configuration/network.md b/docs/configuration/network.md index a295d5353db..e0e3e8e1187 100644 --- a/docs/configuration/network.md +++ b/docs/configuration/network.md @@ -50,6 +50,7 @@ license: | | celeborn.rpc.askTimeout | 60s | false | Timeout for RPC ask operations. 
It's recommended to set at least `240s` when `HDFS` is enabled in `celeborn.storage.activeTypes` | 0.2.0 | | | celeborn.rpc.connect.threads | 64 | false | | 0.2.0 | | | celeborn.rpc.dispatcher.threads | 0 | false | Threads number of message dispatcher event loop. Default to 0, which is availableCore. | 0.3.0 | celeborn.rpc.dispatcher.numThreads | +| celeborn.rpc.inbox.capacity | 0 | false | Maximum number of in-memory messages allowed in each RPC inbox. 0 means unbounded. | 0.5.0 | | | celeborn.rpc.io.threads | <undefined> | false | Netty IO thread number of NettyRpcEnv to handle RPC request. The default threads number is the number of runtime available processors. | 0.2.0 | | | celeborn.rpc.lookupTimeout | 30s | false | Timeout for RPC lookup operations. | 0.2.0 | | | celeborn.shuffle.io.maxChunksBeingTransferred | <undefined> | false | The max number of chunks allowed to be transferred at the same time on shuffle service. Note that new incoming connections will be closed when the max number is hit. The client will retry according to the shuffle retry configs (see `celeborn..io.maxRetries` and `celeborn..io.retryWait`), if those limits are reached the task will fail with fetch failure. | 0.2.0 | |