[CORE] Fix regressions in decommissioning
The DecommissionWorkerSuite started becoming flaky, and the flakiness revealed a real regression. Recent PRs (#28085 and #29211) necessitate a small reworking of the decommissioning logic.

Before getting into that, let me describe the intended behavior of decommissioning:

If a fetch failure happens where the source executor was decommissioned, we want to treat that as an eager signal to clear all shuffle state associated with that executor. In addition, if we know that the host was decommissioned, we want to forget the map statuses of all other executors on that decommissioned host. This is what the test "decommission workers ensure that fetch failures lead to rerun" verifies. This invariant is important to ensure that decommissioning a host does not lead to multiple fetch failures that might fail the job.
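
To make the invariant concrete, here is a minimal sketch of the intended handling. The object and method are hypothetical wrappers, not the actual DAGScheduler code path; only the removeOutputsOnHost/removeOutputsOnExecutor calls on MapOutputTrackerMaster correspond to what the DAGScheduler diff below actually does:

    import org.apache.spark.MapOutputTrackerMaster

    // Hypothetical helper illustrating which shuffle state should be dropped after a
    // fetch failure whose source executor was decommissioned.
    object DecommissionFetchFailureSketch {
      def clearShuffleState(
          tracker: MapOutputTrackerMaster,
          execId: String,
          host: String,
          isHostDecommissioned: Boolean): Unit = {
        if (isHostDecommissioned) {
          // The whole host is going away: forget the map statuses of every executor on it,
          // so a single fetch failure is enough to trigger all the needed recomputation.
          tracker.removeOutputsOnHost(host)
        } else {
          // Only this executor is known to be decommissioned: clear just its outputs.
          tracker.removeOutputsOnExecutor(execId)
        }
      }
    }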

- Per #29211, the executors now eagerly exit on decommissioning, and thus the executor is lost before the fetch failure even happens. (I verified this by waiting a few seconds before triggering the fetch failure.) When an executor is lost, we forget its decommissioning information. The fix is to keep the decommissioning information around for some time after removal, with some extra logic to finally purge it after a timeout (a standalone sketch of this retention scheme follows the testing note below).

- Per #28085, when an executor is lost, the scheduler forgets the shuffle state for just that executor and increments the shuffleFileLostEpoch. This incrementing precludes clearing the state of the entire host when the fetch failure happens. I elected to change this codepath only for the special case of decommissioning, without any other side effects. This whole version-keeping logic is complex and has effectively not been semantically changed since 2013! The fix here is also simple: ignore the shuffleFileLostEpoch when the shuffle status is being cleared due to a fetch failure resulting from host decommission (sketched right after this list).
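
A minimal sketch of that second fix, using a simplified stand-in for the DAGScheduler's epoch bookkeeping (the real change is in the DAGScheduler hunk in the diff below; the object here is hypothetical):

    import scala.collection.mutable

    object ShuffleFileLostEpochSketch {
      // Last epoch at which we already cleared shuffle files for each executor.
      private val shuffleFileLostEpoch = new mutable.HashMap[String, Long]()

      def shouldUnregisterOutputs(
          execId: String,
          currentEpoch: Long,
          ignoreShuffleVersion: Boolean): Boolean = {
        if (!shuffleFileLostEpoch.contains(execId) || shuffleFileLostEpoch(execId) < currentEpoch) {
          // First file loss seen at this epoch: record it and clear the shuffle state.
          shuffleFileLostEpoch(execId) = currentEpoch
          true
        } else {
          // Already handled at this epoch. Normally we skip, but a fetch failure caused by a
          // decommissioned host must still clear the whole host's outputs, so the caller
          // passes ignoreShuffleVersion = true and we clear anyway.
          ignoreShuffleVersion
        }
      }
    }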

These two fixes are local to decommissioning only and don't change other behavior.

I also added more tests to TaskSchedulerImplSuite to ensure that the decommissioning information is indeed purged after a timeout.
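
For reference, the retention scheme behaves roughly like the standalone sketch below. The class and its members are hypothetical stand-ins; in the patch the state lives in TaskSchedulerImpl's executorsPendingDecommission and decommissioningExecutorsToGc, and the grace period comes from "spark.decommissioningRememberAfterRemoval.seconds" (default 60):

    import java.util.concurrent.TimeUnit

    import scala.collection.JavaConverters._
    import scala.collection.mutable

    // Standalone sketch of the retention scheme; all names here are hypothetical.
    class DecommissionInfoTrackerSketch(rememberSeconds: Long, nowMillis: () => Long) {
      // executorId -> opaque decommission info (a String stands in for ExecutorDecommissionInfo)
      private val pendingDecommission = new mutable.HashMap[String, String]()
      // second at which entries become purgeable -> executors to purge at that time
      private val toPurge = new java.util.TreeMap[Long, mutable.ArrayBuffer[String]]()

      def decommission(execId: String, info: String): Unit =
        pendingDecommission(execId) = info

      // Called when the executor is removed: instead of dropping its info immediately,
      // remember it for a grace period so a later fetch failure can still see it.
      def executorLost(execId: String): Unit = {
        if (pendingDecommission.contains(execId)) {
          val purgeAt = TimeUnit.MILLISECONDS.toSeconds(nowMillis()) + rememberSeconds
          toPurge.computeIfAbsent(purgeAt, _ => mutable.ArrayBuffer.empty[String]) += execId
        }
      }

      def get(execId: String): Option[String] = {
        // Lazily drop every entry whose grace period has expired before answering.
        val expired = toPurge.headMap(TimeUnit.MILLISECONDS.toSeconds(nowMillis()))
        expired.values().asScala.flatten.foreach(pendingDecommission -= _)
        expired.clear()
        pendingDecommission.get(execId)
      }
    }

The new TaskSchedulerImplSuite test drives this flow with a ManualClock: after a decommissioned executor is lost its info is still visible, and only once the clock advances past the grace period do both structures become empty.
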
dagrawal3409 committed Aug 13, 2020
1 parent 0c850c7 commit b890443
Showing 4 changed files with 113 additions and 34 deletions.
33 changes: 21 additions & 12 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1846,7 +1846,8 @@ private[spark] class DAGScheduler(
execId = bmAddress.executorId,
fileLost = true,
hostToUnregisterOutputs = hostToUnregisterOutputs,
maybeEpoch = Some(task.epoch))
maybeEpoch = Some(task.epoch),
ignoreShuffleVersion = isHostDecommissioned)
}
}

@@ -2012,7 +2013,8 @@ private[spark] class DAGScheduler(
execId: String,
fileLost: Boolean,
hostToUnregisterOutputs: Option[String],
maybeEpoch: Option[Long] = None): Unit = {
maybeEpoch: Option[Long] = None,
ignoreShuffleVersion: Boolean = false): Unit = {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
logDebug(s"Considering removal of executor $execId; " +
s"fileLost: $fileLost, currentEpoch: $currentEpoch")
@@ -2022,16 +2024,23 @@ private[spark] class DAGScheduler(
blockManagerMaster.removeExecutor(execId)
clearCacheLocs()
}
if (fileLost &&
(!shuffleFileLostEpoch.contains(execId) || shuffleFileLostEpoch(execId) < currentEpoch)) {
shuffleFileLostEpoch(execId) = currentEpoch
hostToUnregisterOutputs match {
case Some(host) =>
logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnHost(host)
case None =>
logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnExecutor(execId)
if (fileLost) {
val remove = if (!shuffleFileLostEpoch.contains(execId) ||
shuffleFileLostEpoch(execId) < currentEpoch) {
shuffleFileLostEpoch(execId) = currentEpoch
true
} else {
ignoreShuffleVersion
}
if (remove) {
hostToUnregisterOutputs match {
case Some(host) =>
logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnHost(host)
case None =>
logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnExecutor(execId)
}
}
}
}
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util
import java.util.{Timer, TimerTask}
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicLong
@@ -136,7 +137,9 @@ private[spark] class TaskSchedulerImpl(
// IDs of the tasks running on each executor
private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
// map of second to list of executors to clear from the above map
val decommissioningExecutorsToGc = new util.TreeMap[Long, mutable.ArrayBuffer[String]]()

def runningTasksByExecutors: Map[String, Int] = synchronized {
executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
@@ -910,7 +913,7 @@ private[spark] class TaskSchedulerImpl(
// if we heard isHostDecommissioned ever true, then we keep that one since it is
// most likely coming from the cluster manager and thus authoritative
val oldDecomInfo = executorsPendingDecommission.get(executorId)
if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) {
if (!oldDecomInfo.exists(_.isHostDecommissioned)) {
executorsPendingDecommission(executorId) = decommissionInfo
}
}
@@ -921,7 +924,13 @@

override def getExecutorDecommissionInfo(executorId: String)
: Option[ExecutorDecommissionInfo] = synchronized {
executorsPendingDecommission.get(executorId)
import scala.collection.JavaConverters._
// Garbage collect old decommissioning entries
val secondsToGcUptil = TimeUnit.MILLISECONDS.toSeconds(clock.getTimeMillis())
val headMap = decommissioningExecutorsToGc.headMap(secondsToGcUptil)
headMap.values().asScala.flatten.foreach(executorsPendingDecommission -= _)
headMap.clear()
executorsPendingDecommission.get(executorId)
}

override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
@@ -1027,7 +1036,15 @@ private[spark] class TaskSchedulerImpl(
}
}

executorsPendingDecommission -= executorId

val decomInfo = executorsPendingDecommission.get(executorId)
if (decomInfo.isDefined) {
val rememberSeconds =
conf.getInt("spark.decommissioningRememberAfterRemoval.seconds", 60)
val gcSecond = TimeUnit.MILLISECONDS.toSeconds(clock.getTimeMillis()) + rememberSeconds
decommissioningExecutorsToGc.computeIfAbsent(gcSecond, _ => mutable.ArrayBuffer.empty) +=
executorId
}

if (reason != LossReasonPending) {
executorIdToHost -= executorId
core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala
@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
}
}

// Unlike TestUtils.withListener, it also waits for the job to be done
def withListener(sc: SparkContext, listener: RootStageAwareListener)
(body: SparkListener => Unit): Unit = {
sc.addSparkListener(listener)
try {
body(listener)
sc.listenerBus.waitUntilEmpty()
listener.waitForJobDone()
} finally {
sc.listenerBus.removeListener(listener)
}
}

test("decommission workers should not result in job failure") {
val maxTaskFailures = 2
val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@ class DecommissionWorkerSuite
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
Thread.sleep(5 * 1000L); 1
}.count()
@@ -164,7 +177,7 @@ class DecommissionWorkerSuite
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
val sleepTimeSeconds = if (pid == 0) 1 else 10
Thread.sleep(sleepTimeSeconds * 1000L)
@@ -212,22 +225,27 @@ class DecommissionWorkerSuite
override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val taskInfo = taskEnd.taskInfo
if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
taskEnd.stageAttemptId == 0) {
taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
decommissionWorkerOnMaster(workerToDecom,
"decommission worker after task on it is done")
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
val executorId = SparkEnv.get.executorId
val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
Thread.sleep(sleepTimeSeconds * 1000L)
val context = TaskContext.get()
if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
Thread.sleep(sleepTimeSeconds * 1000L)
}
List(1).iterator
}, preservesPartitioning = true)
.repartition(1).mapPartitions(iter => {
val context = TaskContext.get()
if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
// Wait a bit for the decommissioning to be triggered in the listener
Thread.sleep(5000)
// MapIndex is explicitly -1 to force the entire host to be decommissioned
// However, this will cause both the tasks in the preceding stage since the host here is
// "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -265,23 +283,31 @@ class DecommissionWorkerSuite
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
jobEnd.jobResult match {
case JobSucceeded => jobDone.set(true)
case JobFailed(exception) => logError(s"Job failed", exception)
}
}

protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}

protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}

private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
String = {
s"${stageId}:${stageAttemptId}:" +
s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
}

override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
logInfo(s"Task started: $signature")
if (isRootStageId(taskStart.stageId)) {
rootTasksStarted.add(taskStart.taskInfo)
handleRootTaskStart(taskStart)
}
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
logInfo(s"Task End $taskSignature")
tasksFinished.add(taskSignature)
if (isRootStageId(taskEnd.stageId)) {
Expand All @@ -291,8 +317,13 @@ class DecommissionWorkerSuite
}

def getTasksFinished(): Seq[String] = {
assert(jobDone.get(), "Job isn't successfully done yet")
tasksFinished.asScala.toSeq
tasksFinished.asScala.toList
}

def waitForJobDone(): Unit = {
eventually(timeout(10.seconds), interval(100.milliseconds)) {
assert(jobDone.get(), "Job isn't successfully done yet")
}
}
}

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.internal.config
import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, TaskResourceRequests}
import org.apache.spark.resource.ResourceUtils._
import org.apache.spark.resource.TestResourceIDs._
import org.apache.spark.util.ManualClock
import org.apache.spark.util.{Clock, ManualClock, SystemClock}

class FakeSchedulerBackend extends SchedulerBackend {
def start(): Unit = {}
@@ -88,10 +88,15 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
}

def setupSchedulerWithMaster(master: String, confs: (String, String)*): TaskSchedulerImpl = {
setupSchedulerWithMasterAndClock(master, new SystemClock, confs: _*)
}

def setupSchedulerWithMasterAndClock(master: String, clock: Clock, confs: (String, String)*):
TaskSchedulerImpl = {
val conf = new SparkConf().setMaster(master).setAppName("TaskSchedulerImplSuite")
confs.foreach { case (k, v) => conf.set(k, v) }
sc = new SparkContext(conf)
taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler = new TaskSchedulerImpl(sc, sc.conf.get(config.TASK_MAX_FAILURES), clock = clock)
setupHelper()
}

@@ -1802,9 +1807,10 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(2 == taskDescriptions.head.resources(GPU).addresses.size)
}

private def setupSchedulerForDecommissionTests(): TaskSchedulerImpl = {
val taskScheduler = setupSchedulerWithMaster(
private def setupSchedulerForDecommissionTests(clock: Clock): TaskSchedulerImpl = {
val taskScheduler = setupSchedulerWithMasterAndClock(
s"local[2]",
clock,
config.CPUS_PER_TASK.key -> 1.toString)
taskScheduler.submitTasks(FakeTask.createTaskSet(2))
val multiCoreWorkerOffers = IndexedSeq(WorkerOffer("executor0", "host0", 1),
@@ -1815,7 +1821,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
}

test("scheduler should keep the decommission info where host was decommissioned") {
val scheduler = setupSchedulerForDecommissionTests()
val scheduler = setupSchedulerForDecommissionTests(new SystemClock)

scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0", false))
scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1", true))
@@ -1829,8 +1835,9 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(scheduler.getExecutorDecommissionInfo("executor2").isEmpty)
}

test("scheduler should ignore decommissioning of removed executors") {
val scheduler = setupSchedulerForDecommissionTests()
test("scheduler should eventually purge removed and decommissioned executors") {
val clock = new ManualClock(10000L)
val scheduler = setupSchedulerForDecommissionTests(clock)

// executor 0 is decommissioned after being lost
assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
@@ -1839,14 +1846,29 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("", false))
assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)

assert(scheduler.executorsPendingDecommission.isEmpty)
clock.advance(5000)

// executor 1 is decommissioned before being lost
assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
clock.advance(2000)
scheduler.executorLost("executor1", ExecutorExited(0, false, "normal"))
assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
assert(scheduler.decommissioningExecutorsToGc.size === 1)
assert(scheduler.executorsPendingDecommission.size === 1)
clock.advance(2000)
// It hasn't been 60 seconds since removal yet
assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
clock.advance(2000)
assert(scheduler.decommissioningExecutorsToGc.size === 1)
assert(scheduler.executorsPendingDecommission.size === 1)
assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
clock.advance(61000)
assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
assert(scheduler.decommissioningExecutorsToGc.isEmpty)
assert(scheduler.executorsPendingDecommission.isEmpty)
}

/**
