diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 18305ad3746a6..f63c52e930036 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -261,4 +261,7 @@ private[deploy] object DeployMessages {
 
   case object SendHeartbeat
 
+  // From LocalSparkCluster to Worker when stop() is called, in order to check whether
+  // the Worker is ready to stop or not.
+  case object IsWorkerReadyToStop
 }
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index d057545afe2e1..d427c20edfa02 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -20,10 +20,11 @@ package org.apache.spark.deploy
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkConf
+import org.apache.spark.deploy.DeployMessages.IsWorkerReadyToStop
 import org.apache.spark.deploy.master.Master
 import org.apache.spark.deploy.worker.Worker
 import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}
 import org.apache.spark.util.Utils
 
 /**
@@ -43,6 +44,7 @@ class LocalSparkCluster(
   private val localHostname = Utils.localHostName()
   private val masterRpcEnvs = ArrayBuffer[RpcEnv]()
   private val workerRpcEnvs = ArrayBuffer[RpcEnv]()
+  private val workerRefs = ArrayBuffer[RpcEndpointRef]()
   // exposed for testing
   var masterWebUIPort = -1
 
@@ -63,10 +65,11 @@
 
     /* Start the Workers */
     for (workerNum <- 1 to numWorkers) {
-      val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker,
-        memoryPerWorker, masters, null, Some(workerNum), _conf,
+      val (workerEnv, workerRef) = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0,
+        coresPerWorker, memoryPerWorker, masters, null, Some(workerNum), _conf,
         conf.get(config.Worker.SPARK_WORKER_RESOURCE_FILE))
       workerRpcEnvs += workerEnv
+      workerRefs += workerRef
     }
 
     masters
@@ -74,10 +77,15 @@
 
   def stop(): Unit = {
     logInfo("Shutting down local Spark cluster.")
-    // SPARK-31922: wait one more second before shutting down rpcEnvs of master and worker,
-    // in order to let the cluster have time to handle the `UnregisterApplication` message.
+    // SPARK-31922: make sure all the workers have handled the messages(`KillExecutor`,
+    // `ApplicationFinished`) from the Master before we shutdown the workers' rpcEnvs.
     // Otherwise, we could hit "RpcEnv already stopped" error.
-    Thread.sleep(1000)
+    var busyWorkers = workerRefs
+    while (busyWorkers.nonEmpty) {
+      Thread.sleep(100)
+      busyWorkers = busyWorkers.filterNot(_.askSync[Boolean](IsWorkerReadyToStop))
+    }
+
     // Stop the workers before the master so they don't get upset that it disconnected
     Seq(workerRpcEnvs, masterRpcEnvs).foreach { rpcEnvArr =>
       rpcEnvArr.foreach { rpcEnv => Utils.tryLog {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index aa8c46fc68315..a3f45b17d4cba 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -153,6 +153,8 @@ private[deploy] class Worker(
   val finishedDrivers = new LinkedHashMap[String, DriverRunner]
   val appDirectories = new HashMap[String, Seq[String]]
   val finishedApps = new HashSet[String]
+  // Used for `LocalSparkCluster` only
+  private var hasAppFinished = false
 
   val retainedExecutors = conf.get(WORKER_UI_RETAINED_EXECUTORS)
   val retainedDrivers = conf.get(WORKER_UI_RETAINED_DRIVERS)
@@ -665,6 +667,7 @@ private[deploy] class Worker(
       reregisterWithMaster()
 
     case ApplicationFinished(id) =>
+      hasAppFinished = true
      finishedApps += id
      maybeCleanupApplication(id)
 
@@ -679,6 +682,9 @@ private[deploy] class Worker(
        finishedDrivers.values.toList, activeMasterUrl, cores, memory,
        coresUsed, memoryUsed, activeMasterWebUiUrl, resources,
        resourcesUsed.toMap.map { case (k, v) => (k, v.toResourceInformation)}))
+
+    case IsWorkerReadyToStop =>
+      context.reply(executors.isEmpty && hasAppFinished)
   }
 
   override def onDisconnected(remoteAddress: RpcAddress): Unit = {
@@ -852,7 +858,7 @@ private[deploy] object Worker extends Logging {
     val args = new WorkerArguments(argStrings, conf)
     val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores,
       args.memory, args.masters, args.workDir, conf = conf,
-      resourceFileOpt = conf.get(SPARK_WORKER_RESOURCE_FILE))
+      resourceFileOpt = conf.get(SPARK_WORKER_RESOURCE_FILE))._1
     // With external shuffle service enabled, if we request to launch multiple workers on one host,
     // we can only successfully launch the first worker and the rest fails, because with the port
     // bound, we may launch no more than one external shuffle service on each host.
@@ -877,16 +883,16 @@ private[deploy] object Worker extends Logging {
       workDir: String,
       workerNumber: Option[Int] = None,
       conf: SparkConf = new SparkConf,
-      resourceFileOpt: Option[String] = None): RpcEnv = {
+      resourceFileOpt: Option[String] = None): (RpcEnv, RpcEndpointRef) = {
 
     // The LocalSparkCluster runs multiple local sparkWorkerX RPC Environments
     val systemName = SYSTEM_NAME + workerNumber.map(_.toString).getOrElse("")
     val securityMgr = new SecurityManager(conf)
     val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
     val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL)
-    rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory,
+    val workerRef = rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory,
       masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr, resourceFileOpt))
-    rpcEnv
+    (rpcEnv, workerRef)
   }
 
   def isUseLocalNodeSSLConfig(cmd: Command): Boolean = {