
Commit

update
jiangxb1987 committed Aug 20, 2018
1 parent 6b8fbbf commit 32ea946
Showing 5 changed files with 200 additions and 56 deletions.
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1863,7 +1863,8 @@ abstract class RDD[T: ClassTag](

// From performance concern, cache the value to avoid repeatedly compute `isBarrier()` on a long
// RDD chain.
@transient protected lazy val isBarrier_ : Boolean = dependencies.exists(_.rdd.isBarrier())
@transient protected lazy val isBarrier_ : Boolean =
dependencies.filter(!_.isInstanceOf[ShuffleDependency[_, _, _]]).exists(_.rdd.isBarrier())
}
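The diff above narrows how the barrier flag propagates: `isBarrier_` now only consults non-shuffle dependencies, so an RDD on the far side of a `ShuffleDependency` is no longer marked as barrier just because an upstream RDD was computed in a barrier stage. A minimal, standalone sketch of that propagation rule (REPL-style Scala with an invented stand-in type `Node` and a boolean shuffle-edge flag, not Spark's real `RDD`/`Dependency` API):

// A lineage node is barrier if it is itself barrier, or if any parent reached
// through a non-shuffle edge is barrier; shuffle edges cut the propagation,
// mirroring the filter on ShuffleDependency in the change above.
case class Node(selfBarrier: Boolean, parents: Seq[(Node, Boolean)]) {
  // Each parent is paired with an `isShuffleEdge` flag.
  lazy val isBarrier: Boolean =
    selfBarrier || parents.filterNot(_._2).exists(_._1.isBarrier)
}

val barrierMap = Node(selfBarrier = true, parents = Nil)
val mapped     = Node(selfBarrier = false, parents = Seq((barrierMap, false)))
val reduced    = Node(selfBarrier = false, parents = Seq((mapped, true)))

assert(mapped.isBarrier)   // narrow dependency: the flag propagates
assert(!reduced.isBarrier) // shuffle dependency: the flag stops here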


125 changes: 70 additions & 55 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1478,9 +1478,11 @@ private[spark] class DAGScheduler(
mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)

case failedResultStage: ResultStage =>
// Mark all the partitions of the result stage to be not finished, to ensure retry
// all the tasks on resubmitted stage attempt.
failedResultStage.activeJob.map(_.resetAllPartitions())
// Abort the failed result stage since we may have committed output for some
// partitions.
val reason = "Could not recover from a failed barrier ResultStage. Most recent " +
s"failure reason: $failureMessage"
abortStage(failedResultStage, reason, None)
}
}
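The hunk above replaces `resetAllPartitions()` with an `abortStage` call for a failed barrier ResultStage: a result task may already have committed output for some partitions, so silently rerunning the whole stage is not safe, whereas shuffle map output can simply be unregistered and recomputed. A simplified sketch of that decision, using invented stand-ins (`MapStageLike`, `ResultStageLike`, `clearMapOutputs`, `abort`) rather than the real DAGScheduler types:

sealed trait FailedBarrierStage
case class MapStageLike(shuffleId: Int) extends FailedBarrierStage
case class ResultStageLike(jobId: Int) extends FailedBarrierStage

def clearMapOutputs(shuffleId: Int): Unit =
  println(s"unregister all map output for shuffle $shuffleId")  // recomputed on retry
def abort(reason: String): Unit =
  println(s"abort: $reason")

def handleFailure(stage: FailedBarrierStage, failureMessage: String): Unit = stage match {
  case MapStageLike(shuffleId) =>
    // Map output is reproducible, so the stage can be resubmitted from scratch.
    clearMapOutputs(shuffleId)
  case ResultStageLike(_) =>
    // Some partitions may already have committed output; fail the job instead.
    abort("Could not recover from a failed barrier ResultStage. " +
      s"Most recent failure reason: $failureMessage")
}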

@@ -1553,62 +1555,75 @@ private[spark] class DAGScheduler(

// Always fail the current stage and retry all the tasks when a barrier task fail.
val failedStage = stageIdToStage(task.stageId)
logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " +
"failed.")
val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" +
failure.toErrorString
try {
// killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask.
val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) failed."
taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason)
} catch {
case e: UnsupportedOperationException =>
// Cannot continue with barrier stage if failed to cancel zombie barrier tasks.
// TODO SPARK-24877 leave the zombie tasks and ignore their completion events.
logWarning(s"Could not kill all tasks for stage $stageId", e)
abortStage(failedStage, "Could not kill zombie barrier tasks for stage " +
s"$failedStage (${failedStage.name})", Some(e))
}
markStageAsFinished(failedStage, Some(message))
if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
logInfo(s"Ignoring task failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
s"(attempt ${failedStage.latestInfo.attemptNumber}) running")
} else {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " +
"failed.")
val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" +
failure.toErrorString
try {
// killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask.
val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) " +
"failed."
taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason)
} catch {
case e: UnsupportedOperationException =>
// Cannot continue with barrier stage if failed to cancel zombie barrier tasks.
// TODO SPARK-24877 leave the zombie tasks and ignore their completion events.
logWarning(s"Could not kill all tasks for stage $stageId", e)
abortStage(failedStage, "Could not kill zombie barrier tasks for stage " +
s"$failedStage (${failedStage.name})", Some(e))
}
markStageAsFinished(failedStage, Some(message))

failedStage.failedAttemptIds.add(task.stageAttemptId)
// TODO Refactor the failure handling logic to combine similar code with that of
// FetchFailed.
val shouldAbortStage =
failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
disallowStageRetryForTest
failedStage.failedAttemptIds.add(task.stageAttemptId)
// TODO Refactor the failure handling logic to combine similar code with that of
// FetchFailed.
val shouldAbortStage =
failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
disallowStageRetryForTest

if (shouldAbortStage) {
val abortMessage = if (disallowStageRetryForTest) {
"Barrier stage will not retry stage due to testing config. Most recent failure " +
s"reason: $message"
if (shouldAbortStage) {
val abortMessage = if (disallowStageRetryForTest) {
"Barrier stage will not retry stage due to testing config. Most recent failure " +
s"reason: $message"
} else {
s"""$failedStage (${failedStage.name})
|has failed the maximum allowable number of
|times: $maxConsecutiveStageAttempts.
|Most recent failure reason: $message
""".stripMargin.replaceAll("\n", " ")
}
abortStage(failedStage, abortMessage, None)
} else {
s"""$failedStage (${failedStage.name})
|has failed the maximum allowable number of
|times: $maxConsecutiveStageAttempts.
|Most recent failure reason: $message""".stripMargin.replaceAll("\n", " ")
}
abortStage(failedStage, abortMessage, None)
} else {
failedStage match {
case failedMapStage: ShuffleMapStage =>
// Mark all the map as broken in the map stage, to ensure retry all the tasks on
// resubmitted stage attempt.
mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)

case failedResultStage: ResultStage =>
// Mark all the partitions of the result stage to be not finished, to ensure retry
// all the tasks on resubmitted stage attempt.
failedResultStage.activeJob.map(_.resetAllPartitions())
}
failedStage match {
case failedMapStage: ShuffleMapStage =>
// Mark all the map as broken in the map stage, to ensure retry all the tasks on
// resubmitted stage attempt.
mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)

// update failedStages and make sure a ResubmitFailedStages event is enqueued
failedStages += failedStage
logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " +
"failure.")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
case failedResultStage: ResultStage =>
// Abort the failed result stage since we may have committed output for some
// partitions.
val reason = "Could not recover from a failed barrier ResultStage. Most recent " +
s"failure reason: $message"
abortStage(failedResultStage, reason, None)
}
// In case multiple task failures triggered for a single stage attempt, ensure we only
// resubmit the failed stage once.
val noResubmitEnqueued = !failedStages.contains(failedStage)
failedStages += failedStage
if (noResubmitEnqueued) {
logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " +
"failure.")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
}
}

case Resubmitted =>
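Two behavioural details are worth calling out in the rewritten block above: failures from a stage attempt that has already been superseded are ignored, and at most one ResubmitFailedStages event is enqueued per attempt even if several barrier tasks fail. The second guard boils down to a set-membership check; a minimal sketch (stages identified by a plain Int here, not the real Stage class):

val failedStages = scala.collection.mutable.Set[Int]()

def maybeEnqueueResubmit(stageId: Int): Boolean = {
  // Only the first barrier task failure for a given attempt schedules the resubmit.
  val noResubmitEnqueued = !failedStages.contains(stageId)
  failedStages += stageId
  noResubmitEnqueued
}

assert(maybeEnqueueResubmit(0))   // first failure: schedule ResubmitFailedStages
assert(!maybeEnqueueResubmit(0))  // later failures of the same attempt: no-op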
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -893,6 +893,10 @@ private[spark] class TaskSetManager(
None
}

if (tasks(index).isBarrier) {
isZombie = true
}

sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info)

if (!isZombie && reason.countTowardsTaskFailures) {
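The four added lines above retire the whole task set once any barrier task fails: the stage will be retried as a unit by the DAGScheduler, so the TaskSetManager should not keep re-launching individual tasks. A toy illustration of that guard (invented `BarrierTaskSetLike` type, not the real TaskSetManager):

class BarrierTaskSetLike(val isBarrier: Boolean) {
  var isZombie = false

  def onTaskFailed(): Unit = {
    if (isBarrier) {
      isZombie = true  // no more task attempts are scheduled from this task set
    }
  }

  def mayScheduleMoreTasks: Boolean = !isZombie
}

val barrierSet = new BarrierTaskSetLike(isBarrier = true)
barrierSet.onTaskFailed()
assert(!barrierSet.mayScheduleMoreTasks)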
106 changes: 106 additions & 0 deletions core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1119,6 +1119,33 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assertDataStructuresEmpty()
}

test("Fail the job if a barrier ResultTask failed") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
.barrier()
.mapPartitions(iter => iter)
submit(reduceRdd, Array(0, 1))

// Complete the map stage.
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostA", 2))))
assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))

// The first ResultTask fails
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
TaskKilled("test"),
null))

// Assert the stage has been cancelled.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failure.getMessage.startsWith("Job aborted due to stage failure: Could not recover " +
"from a failed barrier ResultStage."))
}

/**
* This tests the case where another FetchFailed comes in while the map stage is getting
* re-run.
@@ -2521,6 +2548,85 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
}

test("Barrier task failures from the same stage attempt don't trigger multiple stage retries") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
submit(reduceRdd, Array(0, 1))

val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}

// The map stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)

// The first map task fails with TaskKilled.
runEvent(makeCompletionEvent(
taskSets(0).tasks(0),
TaskKilled("test"),
null))
assert(sparkListener.failedStages === Seq(0))

// The second map task fails with TaskKilled.
runEvent(makeCompletionEvent(
taskSets(0).tasks(1),
TaskKilled("test"),
null))

// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)
}

test("Barrier task failures from a previous stage attempt don't trigger stage retry") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
submit(reduceRdd, Array(0, 1))

val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}

// The map stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)

// The first map task fails with TaskKilled.
runEvent(makeCompletionEvent(
taskSets(0).tasks(0),
TaskKilled("test"),
null))
assert(sparkListener.failedStages === Seq(0))

// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)

// The second map task fails with TaskKilled.
runEvent(makeCompletionEvent(
taskSets(0).tasks(1),
TaskKilled("test"),
null))

// The second map task failure doesn't trigger stage retry.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
}

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
18 changes: 18 additions & 0 deletions core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -1118,4 +1118,22 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(!tsm.isZombie)
assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined)
}

test("mark taskset for a barrier stage as zombie in case a task fails") {
val taskScheduler = setupScheduler()

val attempt = FakeTask.createBarrierTaskSet(3)
taskScheduler.submitTasks(attempt)

val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
val offers = (0 until 3).map{ idx =>
WorkerOffer(s"exec-$idx", s"host-$idx", 1, Some(s"192.168.0.101:4962$idx"))
}
taskScheduler.resourceOffers(offers)
assert(tsm.runningTasks === 3)

// Fail a task from the stage attempt.
tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, TaskKilled("test"))
assert(tsm.isZombie)
}
}
