Changes to the test and the logic: ignore fetch failures from decommissioned hosts when deciding to abort a stage, and forcefully handle the fetch failure.
dagrawal3409 committed Aug 13, 2020
1 parent 0c850c7 commit aaacf30
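
In short, the DAGScheduler now checks whether the executor named in a FetchFailed is already decommissioned: such failures no longer count toward aborting the stage, a decommissioned host is treated as having lost all of its shuffle output, and its map outputs are unregistered unconditionally (no epoch check). The sketch below is a minimal, self-contained restatement of that decision logic; DecomInfo and the method names are illustrative stand-ins for this sketch, not Spark's actual classes.

// Illustrative model of the decision this commit adds (hypothetical names, not Spark APIs).
case class DecomInfo(isHostDecommissioned: Boolean)

object FetchFailureDecision {
  // A fetch failure whose source executor is already decommissioned does not count toward
  // aborting the stage.
  def shouldAbortStage(
      sourceDecommissioned: Option[DecomInfo],
      failedAttemptCount: Int,
      maxConsecutiveStageAttempts: Int,
      disallowStageRetryForTest: Boolean): Boolean = {
    (sourceDecommissioned.isEmpty && failedAttemptCount >= maxConsecutiveStageAttempts) ||
      disallowStageRetryForTest
  }

  // Shuffle output of every executor on the host is presumed lost if the external shuffle
  // service serves it, or if the whole host is being decommissioned.
  def entireHostOutputLost(
      externalShuffleServiceEnabled: Boolean,
      sourceDecommissioned: Option[DecomInfo]): Boolean = {
    externalShuffleServiceEnabled || sourceDecommissioned.exists(_.isHostDecommissioned)
  }

  // Passing None for the epoch means "unregister the map outputs unconditionally".
  def epochForUnregister(
      sourceDecommissioned: Option[DecomInfo],
      taskEpoch: Long): Option[Long] = {
    if (sourceDecommissioned.exists(_.isHostDecommissioned)) None else Some(taskEpoch)
  }
}
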
Showing 3 changed files with 75 additions and 18 deletions.
23 changes: 17 additions & 6 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1667,6 +1667,11 @@ private[spark] class DAGScheduler(
case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleIdToMapStage(shuffleId)
val sourceDecommissioned = if (bmAddress != null && bmAddress.executorId != null) {
taskScheduler.getExecutorDecommissionInfo(bmAddress.executorId)
} else {
None
}

if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
@@ -1675,7 +1680,8 @@
} else {
failedStage.failedAttemptIds.add(task.stageAttemptId)
val shouldAbortStage =
failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
sourceDecommissioned.isEmpty &&
failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
disallowStageRetryForTest

// It is likely that we receive multiple FetchFailed for a single stage (because we have
@@ -1824,16 +1830,14 @@
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled
val isHostDecommissioned = taskScheduler
.getExecutorDecommissionInfo(bmAddress.executorId)
.exists(_.isHostDecommissioned)
val sourceHostDecommissioned = sourceDecommissioned.exists(_.isHostDecommissioned)

// Shuffle output of all executors on host `bmAddress.host` may be lost if:
// - External shuffle service is enabled, so we assume that all shuffle data on node is
// bad.
// - Host is decommissioned, thus all executors on that host will die.
val shuffleOutputOfEntireHostLost = externalShuffleServiceEnabled ||
isHostDecommissioned
sourceHostDecommissioned
val hostToUnregisterOutputs = if (shuffleOutputOfEntireHostLost
&& unRegisterOutputOnHostOnFetchFailure) {
Some(bmAddress.host)
Expand All @@ -1842,11 +1846,18 @@ private[spark] class DAGScheduler(
// reason to believe shuffle data has been lost for the entire host).
None
}
val maybeEpoch = if (sourceHostDecommissioned) {
// If we know that the host has been decommissioned, remove its map outputs
// unconditionally
None
} else {
Some(task.epoch)
}
removeExecutorAndUnregisterOutputs(
execId = bmAddress.executorId,
fileLost = true,
hostToUnregisterOutputs = hostToUnregisterOutputs,
maybeEpoch = Some(task.epoch))
maybeEpoch)
}
}

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util
import java.util.{Timer, TimerTask}
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicLong
@@ -137,6 +138,8 @@ private[spark] class TaskSchedulerImpl(
private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
// Map from expiry second to the executors to clear from the above map at that time
private val decommissioningExecutorsToGc = new util.TreeMap[Long, mutable.ArrayBuffer[String]]()

def runningTasksByExecutors: Map[String, Int] = synchronized {
executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
@@ -921,7 +924,13 @@

override def getExecutorDecommissionInfo(executorId: String)
: Option[ExecutorDecommissionInfo] = synchronized {
executorsPendingDecommission.get(executorId)
import scala.collection.JavaConverters._
// Garbage collect old decommissioning entries
val secondToGcUptil = math.floor(clock.getTimeMillis() / 1000.0).toLong
val headMap = decommissioningExecutorsToGc.headMap(secondToGcUptil)
headMap.values().asScala.flatten.foreach(executorsPendingDecommission -= _)
headMap.clear()
executorsPendingDecommission.get(executorId)
}

override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
@@ -1027,7 +1036,13 @@
}
}

executorsPendingDecommission -= executorId

val decomInfo = executorsPendingDecommission.get(executorId)
if (decomInfo.isDefined) {
// TODO(dagrawal): make this timestamp configurable
val gcSecond = math.ceil(clock.getTimeMillis() / 1000.0).toLong + 60
        // getOrDefault would not store a newly created bucket back in the map, so use
        // computeIfAbsent to register the bucket before appending.
        decommissioningExecutorsToGc
          .computeIfAbsent(gcSecond, _ => mutable.ArrayBuffer.empty[String]) += executorId
}

if (reason != LossReasonPending) {
executorIdToHost -= executorId
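
The TaskSchedulerImpl change above keeps the decommission info around for a grace period after the executor is lost and garbage-collects it lazily on the next lookup, bucketed by expiry second in a java.util.TreeMap. Below is a small, self-contained sketch of that pattern under the same 60-second retention assumption; DecommissionCache and its members are hypothetical names for this sketch, not the real Spark class.

import java.util.TreeMap

import scala.collection.JavaConverters._
import scala.collection.mutable

// Hypothetical sketch of the time-bucketed, lazily garbage-collected cache used above.
class DecommissionCache(retentionSeconds: Long = 60) {
  // executorId -> opaque decommission info
  private val pending = mutable.HashMap.empty[String, String]
  // expiry second -> executor ids whose info can be dropped at or after that second
  private val toGc = new TreeMap[Long, mutable.ArrayBuffer[String]]()

  def markDecommissioned(executorId: String, info: String): Unit = synchronized {
    pending(executorId) = info
  }

  // Called when the executor is finally lost: schedule its info for removal later, so that
  // late fetch failures can still see that the source executor was decommissioned.
  def markLost(executorId: String, nowMillis: Long): Unit = synchronized {
    if (pending.contains(executorId)) {
      val gcSecond = math.ceil(nowMillis / 1000.0).toLong + retentionSeconds
      val bucket = toGc.get(gcSecond)
      if (bucket == null) {
        toGc.put(gcSecond, mutable.ArrayBuffer(executorId))
      } else {
        bucket += executorId
      }
    }
  }

  def get(executorId: String, nowMillis: Long): Option[String] = synchronized {
    // Lazily drop every bucket whose expiry second is strictly before the current second.
    val expired = toGc.headMap(nowMillis / 1000)
    expired.values().asScala.flatten.foreach(pending -= _)
    expired.clear()
    pending.get(executorId)
  }
}
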
core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala
@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
}
}

// Unlike TestUtils.withListener, it also waits for the job to be done
def withListener(sc: SparkContext, listener: RootStageAwareListener)
(body: SparkListener => Unit): Unit = {
sc.addSparkListener(listener)
try {
body(listener)
sc.listenerBus.waitUntilEmpty()
listener.waitForJobDone()
} finally {
sc.listenerBus.removeListener(listener)
}
}

test("decommission workers should not result in job failure") {
val maxTaskFailures = 2
val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
Thread.sleep(5 * 1000L); 1
}.count()
@@ -164,7 +177,7 @@
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
val sleepTimeSeconds = if (pid == 0) 1 else 10
Thread.sleep(sleepTimeSeconds * 1000L)
@@ -212,22 +225,27 @@
override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val taskInfo = taskEnd.taskInfo
if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
taskEnd.stageAttemptId == 0) {
taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
decommissionWorkerOnMaster(workerToDecom,
"decommission worker after task on it is done")
}
}
}
TestUtils.withListener(sc, listener) { _ =>
withListener(sc, listener) { _ =>
val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
val executorId = SparkEnv.get.executorId
val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
Thread.sleep(sleepTimeSeconds * 1000L)
val context = TaskContext.get()
if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
Thread.sleep(sleepTimeSeconds * 1000L)
}
List(1).iterator
}, preservesPartitioning = true)
.repartition(1).mapPartitions(iter => {
val context = TaskContext.get()
if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
// Wait a bit for the decommissioning to be triggered in the listener
Thread.sleep(5000)
// MapIndex is explicitly -1 to force the entire host to be decommissioned
// However, this will cause both the tasks in the preceding stage to be rerun since the host here is
// "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -265,23 +283,31 @@
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
jobEnd.jobResult match {
case JobSucceeded => jobDone.set(true)
case JobFailed(exception) => logError(s"Job failed", exception)
}
}

protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}

protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}

private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
String = {
s"${stageId}:${stageAttemptId}:" +
s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
}

override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
logInfo(s"Task started: $signature")
if (isRootStageId(taskStart.stageId)) {
rootTasksStarted.add(taskStart.taskInfo)
handleRootTaskStart(taskStart)
}
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
logInfo(s"Task End $taskSignature")
tasksFinished.add(taskSignature)
if (isRootStageId(taskEnd.stageId)) {
@@ -291,8 +317,13 @@
}

def getTasksFinished(): Seq[String] = {
assert(jobDone.get(), "Job isn't successfully done yet")
tasksFinished.asScala.toSeq
tasksFinished.asScala.toList
}

def waitForJobDone(): Unit = {
eventually(timeout(10.seconds), interval(100.milliseconds)) {
assert(jobDone.get(), "Job isn't successfully done yet")
}
}
}
