From a28daa2c3283ad31659f840e6d401ab48a42ad88 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 20 Sep 2017 13:35:35 +0800 Subject: [PATCH 1/8] [SPARK-22074][Core] Task killed by other attempt task should not be resubmitted --- .../org/apache/spark/scheduler/TaskInfo.scala | 12 +++ .../spark/scheduler/TaskSetManager.scala | 3 +- .../org/apache/spark/scheduler/FakeTask.scala | 17 +++- .../spark/scheduler/TaskSetManagerSuite.scala | 94 +++++++++++++++++++ 4 files changed, 123 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 9843eab4f1346..dec18e7a60ceb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -66,6 +66,12 @@ class TaskInfo( */ var finishTime: Long = 0 + /** + * Set this tag when this task killed by other attempt. This kind of task should not resubmit + * while executor lost. + */ + var killedAttempt = false + var failed = false var killed = false @@ -74,6 +80,10 @@ class TaskInfo( gettingResultTime = time } + private[spark] def markKilledAttempt { + killedAttempt = true + } + private[spark] def markFinished(state: TaskState, time: Long) { // finishTime should be set larger than 0, otherwise "finished" below will return false. 
assert(time > 0) @@ -93,6 +103,8 @@ class TaskInfo( def running: Boolean = !finished + def needResubmit: Boolean = !killedAttempt + def status: String = { if (running) { if (gettingResult) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3804ea863b4f9..8526c83d48e32 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -724,6 +724,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") + attemptInfo.markKilledAttempt sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -910,7 +911,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index)) { + if (successful(index) && info.needResubmit) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index fe6de2bd98850..bf95ab3b3b61f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,8 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import org.apache.spark.SparkEnv -import org.apache.spark.TaskContext +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.executor.TaskMetrics class FakeTask( @@ -58,4 +57,18 @@ object FakeTask { } new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } + + def createShuffleMapTaskSet(numTasks: Int, stageId: Int, 
stageAttemptId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new ShuffleMapTask(stageId, stageAttemptId, null, new Partition { + override def index: Int = i + }, prefLocs(i), new Properties, + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) + } + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ae43f4cadc037..d5029fb8c5e10 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -744,6 +744,100 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(resubmittedTasks === 0) } + + test("[SPARK-22074] Task killed by other attempt task should not be resubmitted") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.5") + sc.conf.set("spark.speculation", "true") + + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = {} + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. 
+ var resubmittedTasks = 0 + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += 1 + case _ => + } + } + } + sched.setDAGScheduler(dagScheduler) + + val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host3", "exec3")), + Seq(TaskLocation("host2", "exec2"))) + + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((k, v) <- List( + "exec1" -> "host1", + "exec1" -> "host1", + "exec3" -> "host3", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(k, v, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === k) + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete the 2 tasks and leave 2 task in running + for (id <- Set(0, 1)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. 
+ clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(2, 3)) + + // Offer resource to start the speculative attempt for the running task 2.0 + val taskOption = manager.resourceOffer("exec2", "host2", ANY) + assert(taskOption.isDefined) + val task4 = taskOption.get + assert(task4.index === 2) + assert(task4.taskId === 4) + assert(task4.executorId === "exec2") + assert(task4.attemptNumber === 1) + sched.backend = mock(classOf[SchedulerBackend]) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) + // Verify that it kills other running attempt + verify(sched.backend).killTask(2, "exec3", true, "another attempt succeeded") + // Host 3 Losts, there's only task 2.0 on it, which killed by task 2.1 + manager.executorLost("exec3", "host3", SlaveLost()) + // Check the resubmittedTasks + assert(resubmittedTasks === 0) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler( From 81ac4dc75f91c888a9fefa805915f9420d71b761 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 21 Sep 2017 11:46:24 +0800 Subject: [PATCH 2/8] Comment rewrite --- core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index dec18e7a60ceb..f83249bc842ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -67,7 +67,8 @@ class TaskInfo( var finishTime: Long = 0 /** - * Set this tag when this task killed by other attempt. 
This kind of task should not resubmit + * Set this var when the current task killed by other attempt tasks, this happened while we + * set the `spark.speculation` to true. The task killed by others should not resubmit * while executor lost. */ var killedAttempt = false From 039591d4fb2cbe3292b4f0ce33ba605bed895453 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 26 Sep 2017 13:10:00 +0800 Subject: [PATCH 3/8] Simplify code and nit fix --- .../main/scala/org/apache/spark/scheduler/TaskInfo.scala | 6 ++---- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index f83249bc842ce..11acf7375dbdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -71,7 +71,7 @@ class TaskInfo( * set the `spark.speculation` to true. The task killed by others should not resubmit * while executor lost. 
*/ - var killedAttempt = false + var killedByAttempt = false var failed = false @@ -82,7 +82,7 @@ class TaskInfo( } private[spark] def markKilledAttempt { - killedAttempt = true + killedByAttempt = true } private[spark] def markFinished(state: TaskState, time: Long) { @@ -104,8 +104,6 @@ class TaskInfo( def running: Boolean = !finished - def needResubmit: Boolean = !killedAttempt - def status: String = { if (running) { if (gettingResult) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8526c83d48e32..a0b2b27b2ec8b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -911,7 +911,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index) && info.needResubmit) { + if (successful(index) && !info.killedByAttempt) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 From 7191eff18dcad8f680964b2cb8df8c68c27de801 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 26 Sep 2017 20:27:35 +0800 Subject: [PATCH 4/8] Rename the variable --- .../main/scala/org/apache/spark/scheduler/TaskInfo.scala | 6 +++--- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 11acf7375dbdd..73d951e722988 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -71,7 +71,7 @@ class TaskInfo( * set the `spark.speculation` to true. The task killed by others should not resubmit * while executor lost. 
*/ - var killedByAttempt = false + var killedByOtherAttempt = false var failed = false @@ -81,8 +81,8 @@ class TaskInfo( gettingResultTime = time } - private[spark] def markKilledAttempt { - killedByAttempt = true + private[spark] def markKilledAttempt: Unit = { + killedByOtherAttempt = true } private[spark] def markFinished(state: TaskState, time: Long) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a0b2b27b2ec8b..a4bc652152e2e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -911,7 +911,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index) && !info.killedByAttempt) { + if (successful(index) && !info.killedByOtherAttempt) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 From 71b0d5821ca4d2738fb608073acbb8f0ba1d8d29 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 28 Sep 2017 10:27:58 +0800 Subject: [PATCH 5/8] Change method signature and rename --- core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala | 2 +- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 73d951e722988..6f3651be2f47f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -81,7 +81,7 @@ class TaskInfo( gettingResultTime = time } - private[spark] def markKilledAttempt: Unit = { + private[spark] def markKilledByOtherAttempt(): Unit = { killedByOtherAttempt = true } diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a4bc652152e2e..82a4ea5ca150a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -724,7 +724,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - attemptInfo.markKilledAttempt + attemptInfo.markKilledByOtherAttempt sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, From 94fd257c3a8da6ef4473eab72e826af57b10ed47 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Fri, 29 Sep 2017 12:33:25 +0800 Subject: [PATCH 6/8] Move the killedByOtherAttempt into TaskSetManager and fix other nits --- .../org/apache/spark/scheduler/TaskInfo.scala | 11 -------- .../spark/scheduler/TaskSetManager.scala | 9 +++++-- .../org/apache/spark/scheduler/FakeTask.scala | 7 +++-- .../spark/scheduler/TaskSetManagerSuite.scala | 26 ++++++++++++++----- 4 files changed, 31 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 6f3651be2f47f..9843eab4f1346 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -66,13 +66,6 @@ class TaskInfo( */ var finishTime: Long = 0 - /** - * Set this var when the current task killed by other attempt tasks, this happened while we - * set the `spark.speculation` to true. The task killed by others should not resubmit - * while executor lost. 
- */ - var killedByOtherAttempt = false - var failed = false var killed = false @@ -81,10 +74,6 @@ class TaskInfo( gettingResultTime = time } - private[spark] def markKilledByOtherAttempt(): Unit = { - killedByOtherAttempt = true - } - private[spark] def markFinished(state: TaskState, time: Long) { // finishTime should be set larger than 0, otherwise "finished" below will return false. assert(time > 0) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 82a4ea5ca150a..c01b00f96b84c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -83,6 +83,11 @@ private[spark] class TaskSetManager( val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) + // Set the corresponding index of this array when the task is killed by another attempt; + // this happens while `spark.speculation` is set to true. A task killed by another attempt + // should not be resubmitted when its executor is lost.
+ private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 @@ -724,7 +729,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - attemptInfo.markKilledByOtherAttempt + killedByOtherAttempt(index) = true sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -911,7 +916,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index) && !info.killedByOtherAttempt) { + if (successful(index) && !killedByOtherAttempt(index)) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index bf95ab3b3b61f..109d4a0a870b8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -58,8 +58,11 @@ object FakeTask { new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } - def createShuffleMapTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, - prefLocs: Seq[TaskLocation]*): TaskSet = { + def createShuffleMapTaskSet( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { if (prefLocs.size != 0 && prefLocs.size != numTasks) { throw new IllegalArgumentException("Wrong number of task locations") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index d5029fb8c5e10..12ef189149860 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -760,7 +760,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg taskId: Long, executorId: String, interruptThread: Boolean, - reason: String): Unit = {} + reason: String): Unit = { + // Check the only one killTask event in this case, which triggered by + // task 2.1 completed. + assert(taskId === 2) + assert(executorId === "exec3") + assert(interruptThread) + assert(reason === "another attempt succeeded") + } }) // Keep track of the number of tasks that are resubmitted, @@ -794,15 +801,20 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg task.metrics.internalAccums } // Offer resources for 4 tasks to start - for ((k, v) <- List( + for ((exec, host) <- Seq( "exec1" -> "host1", "exec1" -> "host1", "exec3" -> "host3", "exec2" -> "host2")) { - val taskOption = manager.resourceOffer(k, v, NO_PREF) + val taskOption = manager.resourceOffer(exec, host, NO_PREF) assert(taskOption.isDefined) val task = taskOption.get - assert(task.executorId === k) + assert(task.executorId === exec) + // Add an extra assert to make sure task 2.0 is running on exec3 + if (task.index == 2) { + assert(task.attemptNumber === 0) + assert(task.executorId === "exec3") + } } assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) clock.advance(1) @@ -827,11 +839,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(task4.taskId === 4) assert(task4.executorId === "exec2") assert(task4.attemptNumber === 1) - sched.backend = mock(classOf[SchedulerBackend]) // Complete the speculative attempt for the running task manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) - // Verify that it kills other running attempt - verify(sched.backend).killTask(2, "exec3", true, "another attempt succeeded") + // With this successful task 
end, the sched.backend will kill other running attempt, + // verify the request of killTask(2, "exec3", true, "another attempt succeeded") in + // FakeDAGScheduler subclass // Host 3 Losts, there's only task 2.0 on it, which killed by task 2.1 manager.executorLost("exec3", "host3", SlaveLost()) // Check the resubmittedTasks From 5d09b3ba410ca69b3c67d219ff2a4ad9998db7d8 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sat, 30 Sep 2017 14:38:49 +0800 Subject: [PATCH 7/8] Nit fix --- .../spark/scheduler/TaskSetManagerSuite.scala | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 12ef189149860..6b5d58620bb84 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -19,20 +19,18 @@ package org.apache.spark.scheduler import java.util.{Properties, Random} -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - +import org.apache.spark._ +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.{AccumulatorV2, ManualClock} import org.mockito.Matchers.{any, anyInt, anyString} import org.mockito.Mockito.{mock, never, spy, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config -import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AccumulatorV2, ManualClock, Utils} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer class FakeDAGScheduler(sc: SparkContext, taskScheduler: 
FakeTaskScheduler) extends DAGScheduler(sc) { @@ -162,7 +160,7 @@ class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { } class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging { - import TaskLocality.{ANY, PROCESS_LOCAL, NO_PREF, NODE_LOCAL, RACK_LOCAL} + import TaskLocality.{ANY, NODE_LOCAL, NO_PREF, PROCESS_LOCAL, RACK_LOCAL} private val conf = new SparkConf @@ -753,20 +751,22 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc.conf.set("spark.speculation.quantile", "0.5") sc.conf.set("spark.speculation", "true") + var killTaskCalled = false val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) sched.initialize(new FakeSchedulerBackend() { override def killTask( - taskId: Long, - executorId: String, - interruptThread: Boolean, - reason: String): Unit = { + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { // Check the only one killTask event in this case, which triggered by // task 2.1 completed. 
assert(taskId === 2) assert(executorId === "exec3") assert(interruptThread) assert(reason === "another attempt succeeded") + killTaskCalled = true } }) @@ -841,9 +841,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(task4.attemptNumber === 1) // Complete the speculative attempt for the running task manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) - // With this successful task end, the sched.backend will kill other running attempt, - // verify the request of killTask(2, "exec3", true, "another attempt succeeded") in - // FakeDAGScheduler subclass + // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called + assert(killTaskCalled) // Host 3 Losts, there's only task 2.0 on it, which killed by task 2.1 manager.executorLost("exec3", "host3", SlaveLost()) // Check the resubmittedTasks From 1c8c84937e85302f2ac48bcbdbdb5507c9b445e4 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sat, 30 Sep 2017 16:06:32 +0800 Subject: [PATCH 8/8] Fix IDE style change --- .../spark/scheduler/TaskSetManagerSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6b5d58620bb84..d50031047f1be 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -19,18 +19,20 @@ package org.apache.spark.scheduler import java.util.{Properties, Random} -import org.apache.spark._ -import org.apache.spark.internal.{Logging, config} -import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AccumulatorV2, ManualClock} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.mockito.Matchers.{any, 
anyInt, anyString} import org.mockito.Mockito.{mock, never, spy, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.{AccumulatorV2, ManualClock, Utils} class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -160,7 +162,7 @@ class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { } class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging { - import TaskLocality.{ANY, NODE_LOCAL, NO_PREF, PROCESS_LOCAL, RACK_LOCAL} + import TaskLocality.{ANY, PROCESS_LOCAL, NO_PREF, NODE_LOCAL, RACK_LOCAL} private val conf = new SparkConf