[SPARK-32613][CORE] Fix regressions in DecommissionWorkerSuite
### What changes were proposed in this pull request?

The DecommissionWorkerSuite started becoming flaky, and it revealed a real regression. The recently merged #29211 makes it necessary to remember an executor's decommissioning information for a short time after the executor is removed.

In addition to fixing this issue, this PR ensures that DecommissionWorkerSuite continues to pass when executors haven't had a chance to exit eagerly; that is, the old behavior from before #29211 still works.

Added some more tests to TaskSchedulerImpl to ensure that the decommissioning information is indeed purged after a timeout.
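A minimal, self-contained sketch of the purge behavior those tests exercise, using the same Guava cache-plus-ticker pattern that this patch adds to TaskSchedulerImpl; `FakeTicker` and the string values are illustrative stand-ins, not the suite's actual helpers:

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

// Illustrative stand-in for the clock-backed ticker used in TaskSchedulerImpl.
class FakeTicker extends Ticker {
  private val nowNanos = new AtomicLong(0L)
  override def read(): Long = nowNanos.get()
  def advanceSeconds(seconds: Long): Unit = {
    nowNanos.addAndGet(TimeUnit.SECONDS.toNanos(seconds))
  }
}

object DecomCachePurgeSketch {
  def main(args: Array[String]): Unit = {
    val ticker = new FakeTicker
    // Mirrors DECOMMISSIONED_EXECUTORS_REMEMBER_AFTER_REMOVAL_TTL (default 5 minutes).
    val removedDecomInfo = CacheBuilder.newBuilder()
      .expireAfterWrite(300, TimeUnit.SECONDS)
      .ticker(ticker)
      .build[String, String]()
      .asMap()

    removedDecomInfo.put("exec-1", "decommissioned (host lost)")
    assert(removedDecomInfo.containsKey("exec-1"))  // still remembered right after removal

    ticker.advanceSeconds(301)                      // jump past the TTL
    assert(!removedDecomInfo.containsKey("exec-1")) // entry has been purged
  }
}
```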

Hardened the test DecommissionWorkerSuite to make it wait for successful job completion.

### Why are the changes needed?

First, let me describe the intended behavior of decommissioning: If a fetch failure happens where the source executor was decommissioned, we want to treat that as an eager signal to clear all shuffle state associated with that executor. In addition if we know that the host was decommissioned, we want to forget about all map statuses from all other executors on that decommissioned host. This is what the test "decommission workers ensure that fetch failures lead to rerun" is trying to test. This invariant is important to ensure that decommissioning a host does not lead to multiple fetch failures that might fail the job. This fetch failure can happen before the executor is truly marked "lost" because of heartbeat delays.
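To make the invariant concrete, here is a simplified, standalone sketch of the intended decision; `MapOutputRegistry`, `DecomInfo`, and the tuple-keyed map are illustrative stand-ins, not the real DAGScheduler/MapOutputTracker APIs:

```scala
// Illustrative model of which map outputs to forget when a fetch failure's
// source executor was decommissioned. Names here are hypothetical, not Spark's API.
case class DecomInfo(isHostDecommissioned: Boolean)

class MapOutputRegistry {
  // (executorId, host) -> number of registered map outputs (stand-in for real map statuses)
  private var outputs = Map.empty[(String, String), Int]

  def register(execId: String, host: String, n: Int): Unit =
    outputs += ((execId, host) -> n)

  def onFetchFailure(sourceExec: String, sourceHost: String,
      decomInfo: Option[DecomInfo]): Unit = decomInfo match {
    case Some(info) if info.isHostDecommissioned =>
      // The whole host is going away: forget the outputs of *every* executor on it,
      // so the job takes a single fetch-failure-driven rerun instead of one per executor.
      outputs = outputs.filterNot { case ((_, host), _) => host == sourceHost }
    case Some(_) =>
      // Only this executor was decommissioned: forget just its outputs.
      outputs = outputs.filterNot { case ((exec, _), _) => exec == sourceExec }
    case None =>
      // Not decommission-related: fall through to the normal fetch-failure handling.
  }

  def liveOutputs: Map[(String, String), Int] = outputs
}
```

Under this invariant, decommissioning a host leads to at most one fetch-failure-driven rerun, rather than one per executor on that host.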

- However, #29211 eagerly exits the executors when they are done decommissioning. This removal of the executor was racing with the fetch failure. By the time the fetch failure is triggered, the executor has already been removed and has thus forgotten its decommissioning information. (I tested this by delaying the decommissioning.) The fix is to keep the decommissioning information around for some time after removal, with some extra logic to finally purge it after a timeout.

- In addition, the executor loss can also bump up `shuffleFileLostEpoch` (added in #28848). This happens because when the executor is lost, the scheduler forgets the shuffle state for just that executor and increments the `shuffleFileLostEpoch`. This incrementing precludes clearing the state of the entire host when the fetch failure happens, because the failed task is still reusing the old epoch. The fix here is also simple: ignore the `shuffleFileLostEpoch` when the shuffle status is being cleared due to a fetch failure resulting from host decommission, as sketched below.
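A compact restatement of that epoch check, matching the shape of the DAGScheduler change in the diff below; the mutable map and the enclosing object are simplified stand-ins for the scheduler's internal bookkeeping:

```scala
import scala.collection.mutable

object EpochCheckSketch {
  // Simplified stand-in for DAGScheduler's shuffleFileLostEpoch bookkeeping.
  private val shuffleFileLostEpoch = mutable.HashMap.empty[String, Long]

  def shouldClearShuffleState(
      execId: String,
      currentEpoch: Long,
      ignoreShuffleFileLostEpoch: Boolean): Boolean = {
    if (ignoreShuffleFileLostEpoch) {
      // Fetch failure from a decommissioned host: clear unconditionally, even if an
      // earlier executor-lost event already bumped this executor's epoch.
      true
    } else if (!shuffleFileLostEpoch.contains(execId) ||
        shuffleFileLostEpoch(execId) < currentEpoch) {
      shuffleFileLostEpoch(execId) = currentEpoch
      true
    } else {
      false
    }
  }
}
```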

I am deliberately keeping both of these fixes very local to decommissioning to avoid other regressions. The epoch versioning in particular is tricky (it hasn't been fundamentally changed since it was first introduced in 2013).

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Manually ran DecommissionWorkerSuite several times using a script and ensured it all passed.

### (Internal) Configs added
I added two configs, one of which is essentially meant for testing only (see the example after this list):
- `spark.test.executor.decommission.initial.sleep.millis`: Initial delay used by the decommission shutdown thread. The default is 1 second, the same as before. This is used for testing only and is kept "hidden" (i.e., not added as a config constant) to avoid config bloat.
- `spark.executor.decommission.removed.infoCacheTTL`: How long to keep a removed executor's decommission entry around. It defaults to 5 minutes. It should be on the order of the average time it takes to fetch all of the shuffle data from the mapper to the reducer, which can take a while since the reducers also do a multi-step sort.
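For reference, a sketch of how a test might set both keys when building its SparkConf; the values shown are illustrative, and only the key names come from this patch:

```scala
import org.apache.spark.SparkConf

object DecomTestConfSketch {
  val conf: SparkConf = new SparkConf()
    // Keep removed executors' decommission info around for 5 minutes (the default).
    .set("spark.executor.decommission.removed.infoCacheTTL", "5m")
    // Test-only knob: have the decommission shutdown thread sleep for an hour first,
    // so the executor stays alive long enough for the fetch failure to be observed.
    .set("spark.test.executor.decommission.initial.sleep.millis", (3600 * 1000).toString)
}
```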

Closes #29422 from agrawaldevesh/decom_fixes.

Authored-by: Devesh Agrawal <devesh.agrawal@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
dagrawal3409 authored and cloud-fan committed Aug 18, 2020
1 parent b33066f commit 1ac23de
Showing 6 changed files with 153 additions and 37 deletions.
core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -294,10 +294,15 @@ private[spark] class CoarseGrainedExecutorBackend(
override def run(): Unit = {
var lastTaskRunningTime = System.nanoTime()
val sleep_time = 1000 // 1s

// This config is internal and only used by unit tests to force an executor
// to hang around for longer when decommissioned.
val initialSleepMillis = env.conf.getInt(
"spark.test.executor.decommission.initial.sleep.millis", sleep_time)
if (initialSleepMillis > 0) {
Thread.sleep(initialSleepMillis)
}
while (true) {
logInfo("Checking to see if we can shutdown.")
Thread.sleep(sleep_time)
if (executor == null || executor.numRunningTasks == 0) {
if (env.conf.get(STORAGE_DECOMMISSION_ENABLED)) {
logInfo("No running tasks, checking migrations")
@@ -323,6 +328,7 @@ private[spark] class CoarseGrainedExecutorBackend(
// move forward.
lastTaskRunningTime = System.nanoTime()
}
Thread.sleep(sleep_time)
}
}
}
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1877,6 +1877,16 @@ package object config {
.timeConf(TimeUnit.SECONDS)
.createOptional

private[spark] val DECOMMISSIONED_EXECUTORS_REMEMBER_AFTER_REMOVAL_TTL =
ConfigBuilder("spark.executor.decommission.removed.infoCacheTTL")
.doc("Duration for which a decommissioned executor's information will be kept after its" +
"removal. Keeping the decommissioned info after removal helps pinpoint fetch failures to " +
"decommissioning even after the mapper executor has been decommissioned. This allows " +
"eager recovery from fetch failures caused by decommissioning, increasing job robustness.")
.version("3.1.0")
.timeConf(TimeUnit.SECONDS)
.createWithDefaultString("5m")

private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir")
.doc("Staging directory used while submitting applications.")
.version("2.0.0")
41 changes: 29 additions & 12 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1846,7 +1846,14 @@ private[spark] class DAGScheduler(
execId = bmAddress.executorId,
fileLost = true,
hostToUnregisterOutputs = hostToUnregisterOutputs,
maybeEpoch = Some(task.epoch))
maybeEpoch = Some(task.epoch),
// shuffleFileLostEpoch is ignored when a host is decommissioned because some
// decommissioned executors on that host might have been removed before this fetch
// failure and might have bumped up the shuffleFileLostEpoch. We ignore that, and
// proceed with unconditional removal of shuffle outputs from all executors on that
// host, including from those that we still haven't confirmed as lost due to heartbeat
// delays.
ignoreShuffleFileLostEpoch = isHostDecommissioned)
}
}

@@ -2012,7 +2019,8 @@
execId: String,
fileLost: Boolean,
hostToUnregisterOutputs: Option[String],
maybeEpoch: Option[Long] = None): Unit = {
maybeEpoch: Option[Long] = None,
ignoreShuffleFileLostEpoch: Boolean = false): Unit = {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
logDebug(s"Considering removal of executor $execId; " +
s"fileLost: $fileLost, currentEpoch: $currentEpoch")
@@ -2022,16 +2030,25 @@
blockManagerMaster.removeExecutor(execId)
clearCacheLocs()
}
if (fileLost &&
(!shuffleFileLostEpoch.contains(execId) || shuffleFileLostEpoch(execId) < currentEpoch)) {
shuffleFileLostEpoch(execId) = currentEpoch
hostToUnregisterOutputs match {
case Some(host) =>
logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnHost(host)
case None =>
logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnExecutor(execId)
if (fileLost) {
val remove = if (ignoreShuffleFileLostEpoch) {
true
} else if (!shuffleFileLostEpoch.contains(execId) ||
shuffleFileLostEpoch(execId) < currentEpoch) {
shuffleFileLostEpoch(execId) = currentEpoch
true
} else {
false
}
if (remove) {
hostToUnregisterOutputs match {
case Some(host) =>
logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnHost(host)
case None =>
logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
mapOutputTracker.removeOutputsOnExecutor(execId)
}
}
}
}
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -26,6 +26,9 @@ import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap, HashSet}
import scala.util.Random

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.ExecutorMetrics
@@ -136,7 +139,21 @@ private[spark] class TaskSchedulerImpl(
// IDs of the tasks running on each executor
private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
// We add executors here when we first get decommission notification for them. Executors can
// continue to run even after being asked to decommission, but they will eventually exit.
val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]

// When they exit and we know of that via heartbeat failure, we will add them to this cache.
// This cache is consulted to know if a fetch failure is because a source executor was
// decommissioned.
lazy val decommissionedExecutorsRemoved = CacheBuilder.newBuilder()
.expireAfterWrite(
conf.get(DECOMMISSIONED_EXECUTORS_REMEMBER_AFTER_REMOVAL_TTL), TimeUnit.SECONDS)
.ticker(new Ticker{
override def read(): Long = TimeUnit.MILLISECONDS.toNanos(clock.getTimeMillis())
})
.build[String, ExecutorDecommissionInfo]()
.asMap()

def runningTasksByExecutors: Map[String, Int] = synchronized {
executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
@@ -910,7 +927,7 @@ private[spark] class TaskSchedulerImpl(
// if we heard isHostDecommissioned ever true, then we keep that one since it is
// most likely coming from the cluster manager and thus authoritative
val oldDecomInfo = executorsPendingDecommission.get(executorId)
if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) {
if (!oldDecomInfo.exists(_.isHostDecommissioned)) {
executorsPendingDecommission(executorId) = decommissionInfo
}
}
@@ -921,7 +938,9 @@

override def getExecutorDecommissionInfo(executorId: String)
: Option[ExecutorDecommissionInfo] = synchronized {
executorsPendingDecommission.get(executorId)
executorsPendingDecommission
.get(executorId)
.orElse(Option(decommissionedExecutorsRemoved.get(executorId)))
}

override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
@@ -1027,7 +1046,9 @@
}
}

executorsPendingDecommission -= executorId

val decomInfo = executorsPendingDecommission.remove(executorId)
decomInfo.foreach(decommissionedExecutorsRemoved.put(executorId, _))

if (reason != LossReasonPending) {
executorIdToHost -= executorId
core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala
@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
}
}

// Unlike TestUtils.withListener, it also waits for the job to be done
def withListener(sc: SparkContext, listener: RootStageAwareListener)
(body: SparkListener => Unit): Unit = {
sc.addSparkListener(listener)
try {
body(listener)
sc.listenerBus.waitUntilEmpty()
listener.waitForJobDone()
} finally {
sc.listenerBus.removeListener(listener)
}
}

test("decommission workers should not result in job failure") {
val maxTaskFailures = 2
val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
Thread.sleep(5 * 1000L); 1
}.count()
@@ -164,7 +177,7 @@
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
val sleepTimeSeconds = if (pid == 0) 1 else 10
Thread.sleep(sleepTimeSeconds * 1000L)
@@ -190,10 +203,11 @@
}
}

test("decommission workers ensure that fetch failures lead to rerun") {
def testFetchFailures(initialSleepMillis: Int): Unit = {
createWorkers(2)
sc = createSparkContext(
config.Tests.TEST_NO_STAGE_RETRY.key -> "false",
"spark.test.executor.decommission.initial.sleep.millis" -> initialSleepMillis.toString,
config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE.key -> "true")
val executorIdToWorkerInfo = getExecutorToWorkerAssignments
val executorToDecom = executorIdToWorkerInfo.keysIterator.next
@@ -212,22 +226,29 @@
override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val taskInfo = taskEnd.taskInfo
if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
taskEnd.stageAttemptId == 0) {
taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
decommissionWorkerOnMaster(workerToDecom,
"decommission worker after task on it is done")
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
val executorId = SparkEnv.get.executorId
val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
Thread.sleep(sleepTimeSeconds * 1000L)
val context = TaskContext.get()
// Only sleep in the first attempt to create the required window for decommissioning.
// Subsequent attempts don't need to be delayed to speed up the test.
if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
Thread.sleep(sleepTimeSeconds * 1000L)
}
List(1).iterator
}, preservesPartitioning = true)
.repartition(1).mapPartitions(iter => {
val context = TaskContext.get()
if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
// Wait a bit for the decommissioning to be triggered in the listener
Thread.sleep(5000)
// MapIndex is explicitly -1 to force the entire host to be decommissioned
// However, this will cause both the tasks in the preceding stage since the host here is
// "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -246,6 +267,14 @@
assert(tasksSeen.size === 6, s"Expected 6 tasks but got $tasksSeen")
}

test("decommission stalled workers ensure that fetch failures lead to rerun") {
testFetchFailures(3600 * 1000)
}

test("decommission eager workers ensure that fetch failures lead to rerun") {
testFetchFailures(0)
}

private abstract class RootStageAwareListener extends SparkListener {
private var rootStageId: Option[Int] = None
private val tasksFinished = new ConcurrentLinkedQueue[String]()
@@ -265,23 +294,31 @@
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
jobEnd.jobResult match {
case JobSucceeded => jobDone.set(true)
case JobFailed(exception) => logError(s"Job failed", exception)
}
}

protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}

protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}

private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
String = {
s"${stageId}:${stageAttemptId}:" +
s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
}

override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
logInfo(s"Task started: $signature")
if (isRootStageId(taskStart.stageId)) {
rootTasksStarted.add(taskStart.taskInfo)
handleRootTaskStart(taskStart)
}
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
logInfo(s"Task End $taskSignature")
tasksFinished.add(taskSignature)
if (isRootStageId(taskEnd.stageId)) {
@@ -291,8 +328,13 @@
}

def getTasksFinished(): Seq[String] = {
assert(jobDone.get(), "Job isn't successfully done yet")
tasksFinished.asScala.toSeq
tasksFinished.asScala.toList
}

def waitForJobDone(): Unit = {
eventually(timeout(10.seconds), interval(100.milliseconds)) {
assert(jobDone.get(), "Job isn't successfully done yet")
}
}
}
