Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
new RpcTimeout(get(RPC_LOOKUP_TIMEOUT).milli, RPC_LOOKUP_TIMEOUT.key)
def rpcAskTimeout: RpcTimeout =
new RpcTimeout(get(RPC_ASK_TIMEOUT).milli, RPC_ASK_TIMEOUT.key)
/** Capacity bound for the in-memory RPC inbox (see `celeborn.rpc.inbox.capacity`); 0 disables the bound. */
def rpcInMemoryBoundedInboxCapacity(): Int = get(RPC_INBOX_CAPACITY)
def rpcDispatcherNumThreads(availableCores: Int): Int = {
val num = get(RPC_DISPATCHER_THREADS)
if (num != 0) num else availableCores
Expand Down Expand Up @@ -1592,6 +1595,17 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(0)

val RPC_INBOX_CAPACITY: ConfigEntry[Int] =
  buildConf("celeborn.rpc.inbox.capacity")
    .categories("network")
    // Clarified doc: state what is bounded and the 0-means-unbounded contract,
    // consistent with the checkValue message below.
    .doc("Specifies the maximum number of pending messages held in an RPC endpoint's " +
      "in-memory inbox before message posting blocks. 0 means no limitation.")
    .version("0.5.0")
    .intConf
    .checkValue(
      v => v >= 0,
      "the capacity of inbox must be no less than 0, 0 means no limitation")
    .createWithDefault(0)

val RPC_ROLE_DISPATHER_THREADS: ConfigEntry[Int] =
buildConf("celeborn.<role>.rpc.dispatcher.threads")
.categories("network")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
val name: String,
val endpoint: RpcEndpoint,
val ref: NettyRpcEndpointRef) {
val inbox = new Inbox(ref, endpoint)
val celebornConf = nettyEnv.celebornConf
val inbox = new Inbox(ref, endpoint, celebornConf)
}

private val endpoints: ConcurrentMap[String, EndpointData] =
Expand Down Expand Up @@ -157,7 +158,14 @@ private[celeborn] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
endpointName: String,
message: InboxMessage,
callbackIfStopped: Exception => Unit): Unit = {
val data = synchronized {
endpoints.get(endpointName)
}
if (data != null) {
data.inbox.waitOnFull()
}
val error = synchronized {
// double check
val data = endpoints.get(endpointName)
if (stopped) {
Some(new RpcEnvStoppedException())
Expand Down
249 changes: 165 additions & 84 deletions common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

package org.apache.celeborn.common.rpc.netty

import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy

import scala.util.control.NonFatal

import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.CelebornException
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
Expand Down Expand Up @@ -64,14 +67,21 @@ private[celeborn] case class RemoteProcessConnectionError(
*/
private[celeborn] class Inbox(
val endpointRef: NettyRpcEndpointRef,
val endpoint: RpcEndpoint)
extends Logging {
val endpoint: RpcEndpoint,
val conf: CelebornConf) extends Logging {

inbox => // Give this an alias so we can use it more clearly in closures.

private[netty] val capacity = conf.get(CelebornConf.RPC_INBOX_CAPACITY)

private[netty] val inboxLock = new ReentrantLock()
private[netty] val isFull = inboxLock.newCondition()

@GuardedBy("this")
protected val messages = new java.util.LinkedList[InboxMessage]()

private val messageCount = new AtomicLong(0)

/** True if the inbox (and its associated endpoint) is stopped. */
@GuardedBy("this")
private var stopped = false
Expand All @@ -85,84 +95,130 @@ private[celeborn] class Inbox(
private var numActiveThreads = 0

// OnStart should be the first message to process
// OnStart must be the first message the endpoint processes, so enqueue it at
// construction time. Acquire the lock BEFORE the try block: if lockInterruptibly()
// throws (thread interrupt), the finally clause must not call unlock() on a lock
// this thread never acquired (that would raise IllegalMonitorStateException and
// mask the InterruptedException).
inboxLock.lockInterruptibly()
try {
  messages.add(OnStart)
  // Keep the counter in sync with the queue: process() decrements it for every
  // polled message, including this one.
  messageCount.incrementAndGet()
} finally {
  inboxLock.unlock()
}

/**
 * Appends `message` to the queue and bumps the message counter.
 *
 * Precondition: the caller must already hold `inboxLock` (both post() and stop()
 * do) — signalNotFull() enforces this via require().
 *
 * Signalling immediately after the insert mirrors java.util.concurrent.LinkedBlockingQueue,
 * which signals waiting producers right after a put while spare capacity remains.
 */
def addMessage(message: InboxMessage): Unit = {
  messages.add(message)
  messageCount.incrementAndGet()
  signalNotFull()
  logDebug(s"queue length of ${messageCount.get()} ")
}

/**
 * Dispatches a single inbox message to the matching [[RpcEndpoint]] callback.
 * Exceptions are allowed to propagate so that safelyCall() can route them to the
 * endpoint's onError hook.
 */
private def processInternal(dispatcher: Dispatcher, message: InboxMessage): Unit = {
  message match {
    case RpcMessage(_sender, content, context) =>
      try {
        endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
          content,
          { msg =>
            throw new CelebornException(s"Unsupported message $message from ${_sender}")
          })
      } catch {
        case e: Throwable =>
          context.sendFailure(e)
          // Throw the exception -- this exception will be caught by the safelyCall function.
          // The endpoint's onError function will be called.
          throw e
      }

    case OneWayMessage(_sender, content) =>
      endpoint.receive.applyOrElse[Any, Unit](
        content,
        { msg =>
          throw new CelebornException(s"Unsupported message $message from ${_sender}")
        })

    case OnStart =>
      endpoint.onStart()
      if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
        // Lock BEFORE try: if lockInterruptibly() is interrupted, the finally block
        // must not attempt to unlock a lock this thread never acquired.
        inboxLock.lockInterruptibly()
        try {
          if (!stopped) {
            enableConcurrent = true
          }
        } finally {
          inboxLock.unlock()
        }
      }

    case OnStop =>
      // Same lock-before-try discipline as above.
      inboxLock.lockInterruptibly()
      val activeThreads =
        try {
          inbox.numActiveThreads
        } finally {
          inboxLock.unlock()
        }
      assert(
        activeThreads == 1,
        s"There should be only a single active thread but found $activeThreads threads.")
      dispatcher.removeRpcEndpointRef(endpoint)
      endpoint.onStop()
      assert(isEmpty, "OnStop should be the last message")

    case RemoteProcessConnected(remoteAddress) =>
      endpoint.onConnected(remoteAddress)

    case RemoteProcessDisconnected(remoteAddress) =>
      endpoint.onDisconnected(remoteAddress)

    case RemoteProcessConnectionError(cause, remoteAddress) =>
      endpoint.onNetworkError(cause, remoteAddress)
  }
}

/**
 * Blocks the calling (producer) thread while the inbox holds at least `capacity`
 * messages; returns immediately when capacity <= 0 (unbounded) or the inbox is
 * already stopped.
 *
 * Uses >= rather than == so that multiple producers racing past the check cannot
 * all slip through once the queue is exactly at capacity.
 *
 * NOTE(review): `stopped` is read here without holding inboxLock — acceptable as a
 * fast-path check, but confirm a producer cannot block indefinitely against an
 * inbox that stops while it waits.
 */
private[netty] def waitOnFull(): Unit = {
  if (capacity > 0 && !stopped) {
    // Lock BEFORE try so the finally never unlocks a lock we failed to acquire.
    inboxLock.lockInterruptibly()
    try {
      while (messageCount.get() >= capacity) {
        // await() atomically releases inboxLock while waiting and reacquires it
        // on wake-up, so consumers can make progress and signal us.
        isFull.await()
      }
    } finally {
      inboxLock.unlock()
    }
  }
}

/** Wakes one producer blocked in waitOnFull() once the queue has spare capacity. */
private def signalNotFull(): Unit = {
  // Precondition: the caller already holds inboxLock; Condition.signal() requires it.
  require(inboxLock.isHeldByCurrentThread, "cannot call signalNotFull without holding lock")
  if (capacity > 0 && messageCount.get() < capacity) {
    isFull.signal()
  }
}

/**
* Process stored messages.
*/
def process(dispatcher: Dispatcher): Unit = {
var message: InboxMessage = null
inbox.synchronized {
try {
inboxLock.lockInterruptibly()
if (!enableConcurrent && numActiveThreads != 0) {
return
}
message = messages.poll()
if (message != null) {
numActiveThreads += 1
messageCount.decrementAndGet()
signalNotFull()
} else {
return
}
} finally {
inboxLock.unlock()
}

while (true) {
safelyCall(endpoint, endpointRef.name) {
message match {
case RpcMessage(_sender, content, context) =>
try {
endpoint.receiveAndReply(context).applyOrElse[Any, Unit](
content,
{ msg =>
throw new CelebornException(s"Unsupported message $message from ${_sender}")
})
} catch {
case e: Throwable =>
context.sendFailure(e)
// Throw the exception -- this exception will be caught by the safelyCall function.
// The endpoint's onError function will be called.
throw e
}

case OneWayMessage(_sender, content) =>
endpoint.receive.applyOrElse[Any, Unit](
content,
{ msg =>
throw new CelebornException(s"Unsupported message $message from ${_sender}")
})

case OnStart =>
endpoint.onStart()
if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
inbox.synchronized {
if (!stopped) {
enableConcurrent = true
}
}
}

case OnStop =>
val activeThreads = inbox.synchronized {
inbox.numActiveThreads
}
assert(
activeThreads == 1,
s"There should be only a single active thread but found $activeThreads threads.")
dispatcher.removeRpcEndpointRef(endpoint)
endpoint.onStop()
assert(isEmpty, "OnStop should be the last message")

case RemoteProcessConnected(remoteAddress) =>
endpoint.onConnected(remoteAddress)

case RemoteProcessDisconnected(remoteAddress) =>
endpoint.onDisconnected(remoteAddress)

case RemoteProcessConnectionError(cause, remoteAddress) =>
endpoint.onNetworkError(cause, remoteAddress)
}
processInternal(dispatcher, message)
}

inbox.synchronized {
try {
inboxLock.lockInterruptibly()
// "enableConcurrent" will be set to false after `onStop` is called, so we should check it
// every time.
if (!enableConcurrent && numActiveThreads != 1) {
Expand All @@ -174,37 +230,56 @@ private[celeborn] class Inbox(
if (message == null) {
numActiveThreads -= 1
return
} else {
messageCount.decrementAndGet()
signalNotFull()
}
} finally {
inboxLock.unlock()
}
}
}

def post(message: InboxMessage): Unit = inbox.synchronized {
if (stopped) {
// We already put "OnStop" into "messages", so we should drop further messages
onDrop(message)
} else {
messages.add(message)
false
/**
 * Posts a message to the inbox, dropping it via onDrop() when the inbox has
 * already been stopped so that OnStop remains the last message processed.
 */
def post(message: InboxMessage): Unit = {
  // Lock BEFORE try: if lockInterruptibly() is interrupted, the finally block
  // would otherwise call unlock() on a lock this thread never acquired.
  inboxLock.lockInterruptibly()
  try {
    if (stopped) {
      // We already put "OnStop" into "messages", so we should drop further messages
      onDrop(message)
    } else {
      addMessage(message)
    }
  } finally {
    inboxLock.unlock()
  }
}

def stop(): Unit = inbox.synchronized {
// The following codes should be in `synchronized` so that we can make sure "OnStop" is the last
// message
if (!stopped) {
// We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
// thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
// safely.
enableConcurrent = false
stopped = true
messages.add(OnStop)
// Note: The concurrent events in messages will be processed one by one.
/**
 * Marks the inbox as stopped and enqueues OnStop as the final message.
 * Idempotent: subsequent calls are no-ops.
 */
def stop(): Unit = {
  // Lock BEFORE try (see post()); the critical section also guarantees "OnStop"
  // is the last message added.
  inboxLock.lockInterruptibly()
  try {
    if (!stopped) {
      // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
      // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
      // safely.
      enableConcurrent = false
      stopped = true
      addMessage(OnStop)
      // Note: The concurrent events in messages will be processed one by one.
    }
  } finally {
    inboxLock.unlock()
  }
}

def isEmpty: Boolean = inbox.synchronized {
messages.isEmpty
/** True when no messages are pending. Takes inboxLock for a consistent read. */
def isEmpty: Boolean = {
  // Lock BEFORE try so the finally never unlocks an unacquired lock.
  inboxLock.lockInterruptibly()
  try {
    messages.isEmpty
  } finally {
    inboxLock.unlock()
  }
}

/**
Expand All @@ -222,10 +297,13 @@ private[celeborn] class Inbox(
endpoint: RpcEndpoint,
endpointRefName: String)(action: => Unit): Unit = {
def dealWithFatalError(fatal: Throwable): Unit = {
inbox.synchronized {
try {
inboxLock.lockInterruptibly()
assert(numActiveThreads > 0, "The number of active threads should be positive.")
// Should reduce the number of active threads before throw the error.
numActiveThreads -= 1
} finally {
inboxLock.unlock()
}
logError(
s"An error happened while processing message in the inbox for $endpointRefName",
Expand Down Expand Up @@ -254,8 +332,11 @@ private[celeborn] class Inbox(

/** Current number of threads processing this inbox; exposed only for testing. */
def getNumActiveThreads: Int = {
  // Lock BEFORE try so the finally never unlocks an unacquired lock.
  inboxLock.lockInterruptibly()
  try {
    inbox.numActiveThreads
  } finally {
    inboxLock.unlock()
  }
}
}
Loading