[SPARK-21408][CORE] Better default number of RPC dispatch threads.
Instead of using the host's cpu count, use the number of cores allocated
for the Spark process when sizing the RPC dispatch thread pool. This avoids
creating large thread pools on large machines when the number of allocated
cores is small.
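
A rough sketch of the new sizing rule (illustrative only; the actual change is in Dispatcher below, and the helper name here is hypothetical):

    // Hypothetical helper mirroring the new default: use the allocated core count
    // when it is known (> 0), otherwise fall back to the host's CPU count, with a
    // floor of 2 threads.
    def defaultDispatcherThreads(numUsableCores: Int): Int = {
      val availableCores =
        if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
      math.max(2, availableCores)
    }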

Tested by verifying number of threads with spark.executor.cores set
to 1 and 4; same thing for YARN AM.
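
One hedged way to spot-check the resulting pool size from inside a running driver, executor, or AM JVM (it relies only on the "dispatcher-event-loop" thread-name prefix visible in the Dispatcher change below):

    import scala.collection.JavaConverters._

    // Count live RPC dispatcher threads by their name prefix.
    val dispatcherThreads = Thread.getAllStackTraces.keySet.asScala
      .count(_.getName.startsWith("dispatcher-event-loop"))
    println(s"dispatcher threads: $dispatcherThreads")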

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #18639 from vanzin/SPARK-21408.
Marcelo Vanzin committed Jul 18, 2017
1 parent cde64ad commit 264b0f3
Showing 6 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -243,7 +243,7 @@ object SparkEnv extends Logging {
 
     val systemName = if (isDriver) driverSystemName else executorSystemName
     val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
-      securityManager, clientMode = !isDriver)
+      securityManager, numUsableCores, !isDriver)
 
     // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
     if (isDriver) {
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -40,7 +40,7 @@ private[spark] object RpcEnv {
       conf: SparkConf,
       securityManager: SecurityManager,
       clientMode: Boolean = false): RpcEnv = {
-    create(name, host, host, port, conf, securityManager, clientMode)
+    create(name, host, host, port, conf, securityManager, 0, clientMode)
   }
 
   def create(
@@ -50,9 +50,10 @@
       port: Int,
       conf: SparkConf,
       securityManager: SecurityManager,
+      numUsableCores: Int,
       clientMode: Boolean): RpcEnv = {
     val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
-      clientMode)
+      numUsableCores, clientMode)
     new NettyRpcEnvFactory().create(config)
   }
 }
@@ -201,4 +202,5 @@ private[spark] case class RpcEnvConfig(
     advertiseAddress: String,
     port: Int,
     securityManager: SecurityManager,
+    numUsableCores: Int,
     clientMode: Boolean)
core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -32,8 +32,11 @@ import org.apache.spark.util.ThreadUtils
 
 /**
  * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
+ *
+ * @param numUsableCores Number of CPU cores allocated to the process, for sizing the thread pool.
+ *                       If 0, will consider the available CPUs on the host.
  */
-private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
+private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
 
   private class EndpointData(
       val name: String,
@@ -192,8 +195,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
 
   /** Thread pool used for dispatching messages. */
   private val threadpool: ThreadPoolExecutor = {
+    val availableCores =
+      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
     val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
-      math.max(2, Runtime.getRuntime.availableProcessors()))
+      math.max(2, availableCores))
     val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
     for (i <- 0 until numThreads) {
       pool.execute(new MessageLoop)
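
If the computed default is not wanted, the pool size can still be pinned explicitly via the property named in the hunk above; a minimal sketch:

    import org.apache.spark.SparkConf

    // Fix the dispatcher pool at 8 threads, regardless of allocated cores.
    val conf = new SparkConf()
      .set("spark.rpc.netty.dispatcher.numThreads", "8")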
core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -44,14 +44,15 @@ private[netty] class NettyRpcEnv(
     val conf: SparkConf,
     javaSerializerInstance: JavaSerializerInstance,
     host: String,
-    securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
+    securityManager: SecurityManager,
+    numUsableCores: Int) extends RpcEnv(conf) with Logging {
 
   private[netty] val transportConf = SparkTransportConf.fromSparkConf(
     conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
     "rpc",
     conf.getInt("spark.rpc.io.threads", 0))
 
-  private val dispatcher: Dispatcher = new Dispatcher(this)
+  private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)
 
   private val streamManager = new NettyStreamManager(this)

@@ -451,7 +452,7 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
       new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
     val nettyEnv =
       new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
-        config.securityManager)
+        config.securityManager, config.numUsableCores)
     if (!config.clientMode) {
       val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
         nettyEnv.startServer(config.bindAddress, actualPort)
core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -31,7 +31,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
       port: Int,
       clientMode: Boolean = false): RpcEnv = {
     val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port,
-      new SecurityManager(conf), clientMode)
+      new SecurityManager(conf), 0, clientMode)
     new NettyRpcEnvFactory().create(config)
   }

@@ -47,7 +47,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
   test("advertise address different from bind address") {
     val sparkConf = new SparkConf()
     val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0,
-      new SecurityManager(sparkConf), false)
+      new SecurityManager(sparkConf), 0, false)
     val env = new NettyRpcEnvFactory().create(config)
     try {
       assert(env.address.hostPort.startsWith("example.com:"))
resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -459,8 +459,10 @@ private[spark] class ApplicationMaster(
   }
 
   private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
-    rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, -1, sparkConf, securityMgr,
-      clientMode = true)
+    val hostname = Utils.localHostName
+    val amCores = sparkConf.get(AM_CORES)
+    rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
+      amCores, true)
     val driverRef = waitForSparkDriver()
     addAmIpFilter()
     registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"),
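
With this change the YARN AM sizes its dispatcher pool from its own core allocation; a hedged usage sketch (assuming AM_CORES is backed by the standard spark.yarn.am.cores property):

    import org.apache.spark.SparkConf

    // Allocate 4 cores to the YARN AM; its RPC dispatcher pool is sized from this value.
    val conf = new SparkConf()
      .set("spark.yarn.am.cores", "4")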
