-
Notifications
You must be signed in to change notification settings - Fork 419
[CELEBORN-1314] add capacity-bounded inbox for rpc endpoint #2366
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
794b57a
ba43e9f
ee2adbe
85e1195
511c10b
2255ab5
83d581b
fdadfef
2dacb59
12a939e
8c7f590
732e6ba
8803fb6
839d41e
100d2f9
1e7f228
713393c
11831be
3927948
11ae3cc
8176857
769b0c6
56e14a2
e085d2d
b62bce6
37b8e1f
63c7a42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,10 +17,13 @@ | |
|
|
||
| package org.apache.celeborn.common.rpc.netty | ||
|
|
||
| import java.util.concurrent.atomic.AtomicLong | ||
| import java.util.concurrent.locks.ReentrantLock | ||
| import javax.annotation.concurrent.GuardedBy | ||
|
|
||
| import scala.util.control.NonFatal | ||
|
|
||
| import org.apache.celeborn.common.CelebornConf | ||
| import org.apache.celeborn.common.exception.CelebornException | ||
| import org.apache.celeborn.common.internal.Logging | ||
| import org.apache.celeborn.common.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} | ||
|
|
@@ -64,14 +67,21 @@ private[celeborn] case class RemoteProcessConnectionError( | |
| */ | ||
| private[celeborn] class Inbox( | ||
| val endpointRef: NettyRpcEndpointRef, | ||
| val endpoint: RpcEndpoint) | ||
| extends Logging { | ||
| val endpoint: RpcEndpoint, | ||
| val conf: CelebornConf) extends Logging { | ||
|
|
||
| inbox => // Give this an alias so we can use it more clearly in closures. | ||
|
|
||
| private[netty] val capacity = conf.get(CelebornConf.RPC_INBOX_CAPACITY) | ||
|
|
||
| private[netty] val inboxLock = new ReentrantLock() | ||
| private[netty] val isFull = inboxLock.newCondition() | ||
|
|
||
| @GuardedBy("this") | ||
| protected val messages = new java.util.LinkedList[InboxMessage]() | ||
|
|
||
| private val messageCount = new AtomicLong(0) | ||
|
|
||
| /** True if the inbox (and its associated endpoint) is stopped. */ | ||
| @GuardedBy("this") | ||
| private var stopped = false | ||
|
|
@@ -85,84 +95,130 @@ private[celeborn] class Inbox( | |
| private var numActiveThreads = 0 | ||
|
|
||
| // OnStart should be the first message to process | ||
| inbox.synchronized { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| messages.add(OnStart) | ||
| messageCount.incrementAndGet() | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
|
|
||
| def addMessage(message: InboxMessage): Unit = { | ||
| messages.add(message) | ||
| messageCount.incrementAndGet() | ||
| signalNotFull() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it necessary to call
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is a way for better efficiency? I saw LinkedBlockingQueue does signal this after putting an element in https://github.com/openjdk-mirror/jdk7u-jdk/blob/master/src/share/classes/java/util/concurrent/LinkedBlockingQueue.java#L354 |
||
| logDebug(s"queue length of ${messageCount.get()} ") | ||
| } | ||
|
|
||
| private def processInternal(dispatcher: Dispatcher, message: InboxMessage): Unit = { | ||
| message match { | ||
| case RpcMessage(_sender, content, context) => | ||
| try { | ||
| endpoint.receiveAndReply(context).applyOrElse[Any, Unit]( | ||
| content, | ||
| { msg => | ||
| throw new CelebornException(s"Unsupported message $message from ${_sender}") | ||
| }) | ||
| } catch { | ||
| case e: Throwable => | ||
| context.sendFailure(e) | ||
| // Throw the exception -- this exception will be caught by the safelyCall function. | ||
| // The endpoint's onError function will be called. | ||
| throw e | ||
| } | ||
|
|
||
| case OneWayMessage(_sender, content) => | ||
| endpoint.receive.applyOrElse[Any, Unit]( | ||
| content, | ||
| { msg => | ||
| throw new CelebornException(s"Unsupported message $message from ${_sender}") | ||
| }) | ||
|
|
||
| case OnStart => | ||
| endpoint.onStart() | ||
| if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| if (!stopped) { | ||
| enableConcurrent = true | ||
| } | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| } | ||
|
|
||
| case OnStop => | ||
| val activeThreads = | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| inbox.numActiveThreads | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| assert( | ||
| activeThreads == 1, | ||
| s"There should be only a single active thread but found $activeThreads threads.") | ||
| dispatcher.removeRpcEndpointRef(endpoint) | ||
| endpoint.onStop() | ||
| assert(isEmpty, "OnStop should be the last message") | ||
|
|
||
| case RemoteProcessConnected(remoteAddress) => | ||
| endpoint.onConnected(remoteAddress) | ||
|
|
||
| case RemoteProcessDisconnected(remoteAddress) => | ||
| endpoint.onDisconnected(remoteAddress) | ||
|
|
||
| case RemoteProcessConnectionError(cause, remoteAddress) => | ||
| endpoint.onNetworkError(cause, remoteAddress) | ||
| } | ||
| } | ||
|
|
||
| private[netty] def waitOnFull(): Unit = { | ||
| if (capacity > 0 && !stopped) { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| while (messageCount.get() >= capacity) { | ||
| isFull.await() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC, the while loop holds
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nah, Java's wait/await will automatically release its associated lock when it is called
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation :) I think we should use |
||
| } | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private def signalNotFull(): Unit = { | ||
| // when this is called we assume putLock already being called | ||
| require(inboxLock.isHeldByCurrentThread, "cannot call signalNotFull without holding lock") | ||
| if (capacity > 0 && messageCount.get() < capacity) { | ||
| isFull.signal() | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Process stored messages. | ||
| */ | ||
| def process(dispatcher: Dispatcher): Unit = { | ||
| var message: InboxMessage = null | ||
| inbox.synchronized { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| if (!enableConcurrent && numActiveThreads != 0) { | ||
| return | ||
| } | ||
| message = messages.poll() | ||
| if (message != null) { | ||
| numActiveThreads += 1 | ||
| messageCount.decrementAndGet() | ||
| signalNotFull() | ||
| } else { | ||
| return | ||
| } | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
|
|
||
| while (true) { | ||
| safelyCall(endpoint, endpointRef.name) { | ||
| message match { | ||
| case RpcMessage(_sender, content, context) => | ||
| try { | ||
| endpoint.receiveAndReply(context).applyOrElse[Any, Unit]( | ||
| content, | ||
| { msg => | ||
| throw new CelebornException(s"Unsupported message $message from ${_sender}") | ||
| }) | ||
| } catch { | ||
| case e: Throwable => | ||
| context.sendFailure(e) | ||
| // Throw the exception -- this exception will be caught by the safelyCall function. | ||
| // The endpoint's onError function will be called. | ||
| throw e | ||
| } | ||
|
|
||
| case OneWayMessage(_sender, content) => | ||
| endpoint.receive.applyOrElse[Any, Unit]( | ||
| content, | ||
| { msg => | ||
| throw new CelebornException(s"Unsupported message $message from ${_sender}") | ||
| }) | ||
|
|
||
| case OnStart => | ||
| endpoint.onStart() | ||
| if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { | ||
| inbox.synchronized { | ||
| if (!stopped) { | ||
| enableConcurrent = true | ||
| } | ||
| } | ||
| } | ||
|
|
||
| case OnStop => | ||
| val activeThreads = inbox.synchronized { | ||
| inbox.numActiveThreads | ||
| } | ||
| assert( | ||
| activeThreads == 1, | ||
| s"There should be only a single active thread but found $activeThreads threads.") | ||
| dispatcher.removeRpcEndpointRef(endpoint) | ||
| endpoint.onStop() | ||
| assert(isEmpty, "OnStop should be the last message") | ||
|
|
||
| case RemoteProcessConnected(remoteAddress) => | ||
| endpoint.onConnected(remoteAddress) | ||
|
|
||
| case RemoteProcessDisconnected(remoteAddress) => | ||
| endpoint.onDisconnected(remoteAddress) | ||
|
|
||
| case RemoteProcessConnectionError(cause, remoteAddress) => | ||
| endpoint.onNetworkError(cause, remoteAddress) | ||
| } | ||
| processInternal(dispatcher, message) | ||
| } | ||
|
|
||
| inbox.synchronized { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| // "enableConcurrent" will be set to false after `onStop` is called, so we should check it | ||
| // every time. | ||
| if (!enableConcurrent && numActiveThreads != 1) { | ||
|
|
@@ -174,37 +230,56 @@ private[celeborn] class Inbox( | |
| if (message == null) { | ||
| numActiveThreads -= 1 | ||
| return | ||
| } else { | ||
| messageCount.decrementAndGet() | ||
| signalNotFull() | ||
| } | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| def post(message: InboxMessage): Unit = inbox.synchronized { | ||
| if (stopped) { | ||
| // We already put "OnStop" into "messages", so we should drop further messages | ||
| onDrop(message) | ||
| } else { | ||
| messages.add(message) | ||
| false | ||
| def post(message: InboxMessage): Unit = { | ||
CodingCat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| if (stopped) { | ||
| // We already put "OnStop" into "messages", so we should drop further messages | ||
| onDrop(message) | ||
| } else { | ||
| addMessage(message) | ||
| } | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| } | ||
|
|
||
| def stop(): Unit = inbox.synchronized { | ||
| // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last | ||
| // message | ||
| if (!stopped) { | ||
| // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only | ||
| // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources | ||
| // safely. | ||
| enableConcurrent = false | ||
| stopped = true | ||
| messages.add(OnStop) | ||
| // Note: The concurrent events in messages will be processed one by one. | ||
| def stop(): Unit = { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last | ||
| // message | ||
| if (!stopped) { | ||
| // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only | ||
| // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources | ||
| // safely. | ||
| enableConcurrent = false | ||
| stopped = true | ||
| addMessage(OnStop) | ||
| // Note: The concurrent events in messages will be processed one by one. | ||
| } | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| } | ||
|
|
||
| def isEmpty: Boolean = inbox.synchronized { | ||
| messages.isEmpty | ||
| def isEmpty: Boolean = { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| messages.isEmpty | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -222,10 +297,13 @@ private[celeborn] class Inbox( | |
| endpoint: RpcEndpoint, | ||
| endpointRefName: String)(action: => Unit): Unit = { | ||
| def dealWithFatalError(fatal: Throwable): Unit = { | ||
| inbox.synchronized { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| assert(numActiveThreads > 0, "The number of active threads should be positive.") | ||
| // Should reduce the number of active threads before throw the error. | ||
| numActiveThreads -= 1 | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| logError( | ||
| s"An error happened while processing message in the inbox for $endpointRefName", | ||
|
|
@@ -254,8 +332,11 @@ private[celeborn] class Inbox( | |
|
|
||
| // exposed only for testing | ||
| def getNumActiveThreads: Int = { | ||
| inbox.synchronized { | ||
| try { | ||
| inboxLock.lockInterruptibly() | ||
| inbox.numActiveThreads | ||
| } finally { | ||
| inboxLock.unlock() | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems we need to increment
`messageCount` after adding `OnStart`, because it will decrement when processing the message. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated