diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 3e78c7ae240f3..34c0696bfc4e5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -64,6 +64,16 @@ private[deploy] object DeployMessages {
   case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription],
      driverIds: Seq[String])
 
+  /**
+   * A worker will send this message to the master when it registers with the master. Then the
+   * master will compare the executors and drivers reported by the worker with those it knows
+   * about, and tell the worker to kill any unknown executors and drivers.
+   */
+  case class WorkerLatestState(
+      id: String,
+      executors: Seq[ExecutorDescription],
+      driverIds: Seq[String]) extends DeployMessage
+
   case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage
 
   // Master to Worker
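For readers skimming the patch: after a successful (re-)registration the worker reports everything it is currently running, and the master answers with kill commands for anything it does not recognize. Below is a minimal sketch of the message shape, assuming simplified local stand-ins for the Spark types (these definitions are illustrative, not the real org.apache.spark.deploy.DeployMessages API):

```scala
// Simplified stand-ins for the Spark types; names and fields are illustrative.
object WorkerLatestStateSketch {
  sealed trait DeployMessage
  final case class ExecutorDescription(appId: String, execId: Int, cores: Int, state: String)
  final case class WorkerLatestState(
      id: String,
      executors: Seq[ExecutorDescription],
      driverIds: Seq[String]) extends DeployMessage

  def main(args: Array[String]): Unit = {
    // After (re-)registering with the master, the worker reports its full state.
    val report = WorkerLatestState(
      id = "worker-1",
      executors = Seq(ExecutorDescription("app-1", execId = 0, cores = 2, state = "RUNNING")),
      driverIds = Seq("driver-0"))
    println(report) // the master diffs this against its own bookkeeping
  }
}
```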
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index ff8d29fdb4e16..6b9b1408ee44e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -368,6 +368,30 @@ private[deploy] class Master(
         if (canCompleteRecovery) { completeRecovery() }
       }
 
+    case WorkerLatestState(workerId, executors, driverIds) =>
+      idToWorker.get(workerId) match {
+        case Some(worker) =>
+          for (exec <- executors) {
+            val executorMatches = worker.executors.exists {
+              case (_, e) => e.application.id == exec.appId && e.id == exec.execId
+            }
+            if (!executorMatches) {
+              // master doesn't recognize this executor. So just tell worker to kill it.
+              worker.endpoint.send(KillExecutor(masterUrl, exec.appId, exec.execId))
+            }
+          }
+
+          for (driverId <- driverIds) {
+            val driverMatches = worker.drivers.exists { case (id, _) => id == driverId }
+            if (!driverMatches) {
+              // master doesn't recognize this driver. So just tell worker to kill it.
+              worker.endpoint.send(KillDriver(driverId))
+            }
+          }
+        case None =>
+          logWarning("Worker state from unknown worker: " + workerId)
+      }
+
     case UnregisterApplication(applicationId) =>
       logInfo(s"Received unregister request from application $applicationId")
       idToApp.get(applicationId).foreach(finishApplication)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 283db6c4fe8d5..c18c8c7c8603a 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -374,6 +374,11 @@ private[deploy] class Worker(
           }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
       }
 
+      val execs = executors.values.map { e =>
+        new ExecutorDescription(e.appId, e.execId, e.cores, e.state)
+      }
+      masterRef.send(WorkerLatestState(workerId, execs.toList, drivers.keys.toSeq))
+
     case RegisterWorkerFailed(message) =>
      if (!registered) {
        logError("Worker registration failed: " + message)
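The master-side handler above boils down to a set difference: anything the worker reports that the master has no record of gets a kill command. A self-contained sketch under that reading, with hypothetical names in place of the `WorkerInfo` bookkeeping the real code walks:

```scala
// The reconciliation distilled to plain collections; names are illustrative.
object ReconcileSketch {
  final case class ExecutorId(appId: String, execId: Int)

  // Everything the worker reported that the master has no record of.
  def unknownExecutors(reported: Seq[ExecutorId], known: Set[ExecutorId]): Seq[ExecutorId] =
    reported.filterNot(known)

  def unknownDrivers(reported: Seq[String], known: Set[String]): Seq[String] =
    reported.filterNot(known)

  def main(args: Array[String]): Unit = {
    val reported = Seq(ExecutorId("app-1", 0), ExecutorId("app-1", 1))
    val known = Set(ExecutorId("app-1", 0))
    // The master would send KillExecutor for each entry here:
    println(unknownExecutors(reported, known)) // List(ExecutorId(app-1,1))
    // ...and KillDriver for each of these:
    println(unknownDrivers(Seq("driver-0", "driver-1"), Set("driver-0"))) // List(driver-1)
  }
}
```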
executors and drivers") { + val master = makeMaster() + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askWithRetry[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + + val killedExecutors = new ConcurrentLinkedQueue[(String, Int)]() + val killedDrivers = new ConcurrentLinkedQueue[String]() + val fakeWorker = master.rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = master.rpcEnv + + override def receive: PartialFunction[Any, Unit] = { + case KillExecutor(_, appId, execId) => killedExecutors.add(appId, execId) + case KillDriver(driverId) => killedDrivers.add(driverId) + } + }) + + master.self.ask( + RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080")) + val executors = (0 until 3).map { i => + new ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING) + } + master.self.send(WorkerLatestState("1", executors, driverIds = Seq("0", "1", "2"))) + + eventually(timeout(10.seconds)) { + assert(killedExecutors.asScala.toList.sorted === List("0" -> 0, "1" -> 1, "2" -> 2)) + assert(killedDrivers.asScala.toList.sorted === List("0", "1", "2")) + } + } }