diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index bbd79c8b9653..7c8cfc9f208f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2621,6 +2621,16 @@ package object config {
       .toSequence
       .createWithDefault("org.apache.spark.sql.connect.client" :: Nil)
 
+  private[spark] val LEGACY_ABORT_STAGE_AFTER_KILL_TASKS =
+    ConfigBuilder("spark.scheduler.stage.legacyAbortAfterKillTasks")
+      .doc("Whether to abort a stage after TaskScheduler.killAllTaskAttempts(). This is " +
+        "used to restore the original behavior in case there are any regressions after " +
+        "the stage abort is removed.")
+      .version("4.0.0")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
+
   private[spark] val DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION =
     ConfigBuilder("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled")
       .internal()
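Note on the new flag: it is an internal escape hatch that defaults to `true`, so the legacy kill-then-abort behavior stays on until the new path proves itself. A minimal sketch of flipping it explicitly (the config key is the real one added above; the SparkConf setup around it is illustrative only):

```scala
import org.apache.spark.SparkConf

// Illustrative only: internal flags like this are normally left at their defaults.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("legacy-abort-demo")
  // "true" keeps the old behavior of aborting a stage right after killing its tasks;
  // "false" opts into the kill-without-abort path introduced by this change.
  .set("spark.scheduler.stage.legacyAbortAfterKillTasks", "false")
```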
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 1a51220cdf74..e728d921d290 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,7 +38,7 @@ import org.apache.spark.errors.SparkCoreErrors
 import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
-import org.apache.spark.internal.config.RDD_CACHE_VISIBILITY_TRACKING_ENABLED
+import org.apache.spark.internal.config.{LEGACY_ABORT_STAGE_AFTER_KILL_TASKS, RDD_CACHE_VISIBILITY_TRACKING_ENABLED}
 import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
 import org.apache.spark.network.shuffle.{BlockStoreClient, MergeFinalizerListener}
 import org.apache.spark.network.shuffle.protocol.MergeStatuses
@@ -273,7 +273,13 @@ private[spark] class DAGScheduler(
   private val messageScheduler =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message")
 
-  private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
+  private[spark] var eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
+  // Used for tests only. Some tests use the same thread as the event poster to
+  // process events, to ensure deterministic behavior during the test.
+  private[spark] def setEventProcessLoop(loop: DAGSchedulerEventProcessLoop): Unit = {
+    eventProcessLoop = loop
+  }
+
   taskScheduler.setDAGScheduler(this)
 
   private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf, isDriver = true)
@@ -321,6 +327,9 @@ private[spark] class DAGScheduler(
   private val trackingCacheVisibility: Boolean =
     sc.getConf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED)
 
+  /** Whether to abort a stage after canceling all of its tasks. */
+  private val legacyAbortStageAfterKillTasks = sc.getConf.get(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS)
+
   /**
    * Called by the TaskSetManager to report task's starting.
    */
@@ -2860,8 +2869,11 @@ private[spark] class DAGScheduler(
         // This stage is only used by the job, so finish the stage if it is running.
         val stage = stageIdToStage(stageId)
         if (runningStages.contains(stage)) {
-          try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
-            taskScheduler.cancelTasks(stageId, shouldInterruptTaskThread(job), reason)
+          try { // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask
+            taskScheduler.killAllTaskAttempts(stageId, shouldInterruptTaskThread(job), reason)
+            if (legacyAbortStageAfterKillTasks) {
+              stageFailed(stageId, reason)
+            }
             markStageAsFinished(stage, Some(reason))
           } catch {
             case e: UnsupportedOperationException =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 38c7eb77c62d..1e6de9ef46f3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -54,10 +54,6 @@ private[spark] trait TaskScheduler {
   // Submit a sequence of tasks to run.
   def submitTasks(taskSet: TaskSet): Unit
 
-  // Kill all the tasks in a stage and fail the stage and all the jobs that depend on the stage.
-  // Throw UnsupportedOperationException if the backend doesn't support kill tasks.
-  def cancelTasks(stageId: Int, interruptThread: Boolean, reason: String): Unit
-
   /**
    * Kills a task attempt.
    * Throw UnsupportedOperationException if the backend doesn't support kill a task.
    *
    * @return Whether the task was successfully killed.
    */
   def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean
 
-  // Kill all the running task attempts in a stage.
+  // Kill all the tasks in all the stage attempts of the same stage ID.
   // Throw UnsupportedOperationException if the backend doesn't support kill tasks.
   def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit
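With `cancelTasks` gone from the trait, killing a stage's tasks and failing the stage are now separate steps that the caller composes. Condensed from the `DAGScheduler` hunk above (not a self-contained method; under the legacy flag, `stageFailed` restores the stage-failure side effect that `cancelTasks`' internal abort used to provide):

```scala
// Condensed caller-side pattern; identifiers are those in the DAGScheduler hunk above.
taskScheduler.killAllTaskAttempts(stageId, shouldInterruptTaskThread(job), reason)
if (legacyAbortStageAfterKillTasks) {
  // Legacy path: additionally fail the stage (and the jobs that depend on it).
  stageFailed(stageId, reason)
}
markStageAsFinished(stage, Some(reason))
```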
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 21f62097a4bf..2fd99db74c36 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -296,18 +296,31 @@ private[spark] class TaskSchedulerImpl(
       new TaskSetManager(this, taskSet, maxTaskFailures, healthTrackerOpt, clock)
   }
 
-  override def cancelTasks(
+  // Kill all the tasks in all the stage attempts of the same stage ID. Note that stage attempts
+  // won't be aborted but will be marked as zombie. A stage attempt is finished and cleaned up
+  // once all of its tasks have finished. The stage attempt can still be aborted after the call
+  // to `killAllTaskAttempts` if required.
+  override def killAllTaskAttempts(
       stageId: Int,
       interruptThread: Boolean,
       reason: String): Unit = synchronized {
-    logInfo("Cancelling stage " + stageId)
-    // Kill all running tasks for the stage.
-    killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled: " + reason)
-    // Cancel all attempts for the stage.
+    logInfo(s"Killing all running tasks in stage $stageId: $reason")
     taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
       attempts.foreach { case (_, tsm) =>
-        tsm.abort("Stage %s cancelled".format(stageId))
-        logInfo("Stage %d was cancelled".format(stageId))
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the tasks.
+        // 2. The task set manager has been created but no tasks have been scheduled.
+        //    In this case, simply continue.
+        tsm.runningTasksSet.foreach { tid =>
+          taskIdToExecutorId.get(tid).foreach { execId =>
+            backend.killTask(tid, execId, interruptThread, s"Stage cancelled: $reason")
+          }
+        }
+        tsm.suspend()
+        logInfo("Stage %s.%s was cancelled".format(stageId, tsm.taskSet.stageAttemptId))
       }
     }
   }
@@ -327,27 +340,6 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
-  override def killAllTaskAttempts(
-      stageId: Int,
-      interruptThread: Boolean,
-      reason: String): Unit = synchronized {
-    logInfo(s"Killing all running tasks in stage $stageId: $reason")
-    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
-      attempts.foreach { case (_, tsm) =>
-        // There are two possible cases here:
-        // 1. The task set manager has been created and some tasks have been scheduled.
-        //    In this case, send a kill signal to the executors to kill the task.
-        // 2. The task set manager has been created but no tasks have been scheduled. In this case,
-        //    simply continue.
-        tsm.runningTasksSet.foreach { tid =>
-          taskIdToExecutorId.get(tid).foreach { execId =>
-            backend.killTask(tid, execId, interruptThread, reason)
-          }
-        }
-      }
-    }
-  }
-
   override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
     taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 390689cb8f72..b8ba6375e27a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -1059,6 +1059,17 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }
 
+  // Suspends this TSM so that no new tasks are launched.
+  //
+  // Unlike `abort()`, this function intentionally does not notify the DAGScheduler, to avoid
+  // redundant operations. Callers should therefore only invoke it when the DAGScheduler
+  // already knows about this TSM's failure. For example, this function can be called from
+  // `TaskScheduler.killAllTaskAttempts`, which is invoked by the DAGScheduler.
+  def suspend(): Unit = {
+    isZombie = true
+    maybeFinishTaskSet()
+  }
+
   /** If the given task ID is not in the set of running tasks, adds it.
    *
    * Used to keep track of the number of running tasks, for enforcing scheduling policies.
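For contrast, `suspend()` is a narrower version of the pre-existing `abort()`: both flip the TSM into zombie state and try to wind the task set down, but only `abort()` reports the failure back. A condensed side-by-side (the `abort` body is abbreviated from the existing code, not part of this diff):

```scala
// abort(): zombie + tell the DAGScheduler, which fails the stage and dependent jobs.
def abort(message: String, exception: Option[Throwable] = None): Unit = {
  sched.dagScheduler.taskSetFailed(taskSet, message, exception)
  isZombie = true
  maybeFinishTaskSet()
}

// suspend(): zombie only; the DAGScheduler is assumed to already know about the failure.
def suspend(): Unit = {
  isZombie = true
  maybeFinishTaskSet()
}
```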
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 997fda93bc9c..c15fdf098bb5 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -547,7 +547,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
 
     sc.addSparkListener(new SparkListener {
       override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
-        // release taskCancelledSemaphore when cancelTasks event has been posted
+        // release taskCancelledSemaphore when the killAllTaskAttempts event has been posted
         if (stageCompleted.stageInfo.stageId == 1) {
           taskCancelledSemaphore.release(numElements)
         }
diff --git a/core/src/test/scala/org/apache/spark/TempLocalSparkContext.scala b/core/src/test/scala/org/apache/spark/TempLocalSparkContext.scala
index 6d5fcd1edfb0..80da24cd33f7 100644
--- a/core/src/test/scala/org/apache/spark/TempLocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/TempLocalSparkContext.scala
@@ -51,7 +51,7 @@ trait TempLocalSparkContext extends BeforeAndAfterEach
    */
   def sc: SparkContext = {
     if (_sc == null) {
-      _sc = new SparkContext(_conf)
+      _sc = new SparkContext(conf)
     }
     _sc
   }
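The one-character-looking `TempLocalSparkContext` change is load-bearing for the new suite at the bottom of this diff: building the context from the overridable `conf` method instead of the private `_conf` field lets a subclass inject its own configuration. A sketch of the trait shape under that assumption (condensed and renamed, not the full trait):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Condensed sketch of the relevant TempLocalSparkContext pieces.
trait TempLocalSparkContextSketch {
  private var _sc: SparkContext = _
  private val _conf = new SparkConf().setMaster("local[2]").setAppName("test")

  // Subclasses override this to customize the config (see DAGSchedulerAbortStageOffSuite below).
  def conf: SparkConf = _conf

  def sc: SparkContext = {
    if (_sc == null) {
      _sc = new SparkContext(conf) // was `_conf`, which silently ignored overrides
    }
    _sc
  }
}
```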
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index c55f627075e8..ee037b7dafcd 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture, TimeUnit}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}
 
 import scala.annotation.meta.param
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
 import scala.jdk.CollectionConverters._
 import scala.language.reflectiveCalls
 import scala.util.control.NonFatal
@@ -39,7 +39,7 @@ import org.apache.spark._
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.executor.ExecutorMetrics
 import org.apache.spark.internal.config
-import org.apache.spark.internal.config.Tests
+import org.apache.spark.internal.config.{LEGACY_ABORT_STAGE_AFTER_KILL_TASKS, Tests}
 import org.apache.spark.network.shuffle.ExternalBlockStoreClient
 import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, ResourceProfileBuilder, TaskResourceProfile, TaskResourceRequests}
@@ -54,12 +54,30 @@ import org.apache.spark.util.ArrayImplicits._
 class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
   extends DAGSchedulerEventProcessLoop(dagScheduler) {
 
+  dagScheduler.setEventProcessLoop(this)
+
+  private var isProcessing = false
+  private val eventQueue = new ListBuffer[DAGSchedulerEvent]()
+
   override def post(event: DAGSchedulerEvent): Unit = {
-    try {
-      // Forward event to `onReceive` directly to avoid processing event asynchronously.
-      onReceive(event)
-    } catch {
-      case NonFatal(e) => onError(e)
+    if (isProcessing) {
+      // `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially. So we
+      // should buffer events for subsequent processing instead of processing them recursively.
+      eventQueue += event
+    } else {
+      try {
+        isProcessing = true
+        // Forward event to `onReceive` directly to avoid processing event asynchronously.
+        onReceive(event)
+      } catch {
+        case NonFatal(e) => onError(e)
+      } finally {
+        isProcessing = false
+      }
+      if (eventQueue.nonEmpty) {
+        post(eventQueue.remove(0))
+      }
     }
   }
 
@@ -168,7 +186,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   /** Set of TaskSets the DAGScheduler has requested executed. */
   val taskSets = scala.collection.mutable.Buffer[TaskSet]()
 
-  /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
+  /** Stages for which the DAGScheduler has called TaskScheduler.killAllTaskAttempts(). */
   val cancelledStages = new HashSet[Int]()
 
   val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
@@ -189,13 +207,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
       taskSets += taskSet
     }
-    override def cancelTasks(stageId: Int, interruptThread: Boolean, reason: String): Unit = {
-      cancelledStages += stageId
-    }
     override def killTaskAttempt(
       taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
-      stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+      stageId: Int, interruptThread: Boolean, reason: String): Unit = {
+      cancelledStages += stageId
+    }
     override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
       taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
         val tasks = ts.tasks.filter(_.partitionId == partitionId)
@@ -867,15 +884,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       override def submitTasks(taskSet: TaskSet): Unit = {
        taskSets += taskSet
       }
-      override def cancelTasks(stageId: Int, interruptThread: Boolean, reason: String): Unit = {
-        throw new UnsupportedOperationException
-      }
       override def killTaskAttempt(
           taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
         throw new UnsupportedOperationException
       }
       override def killAllTaskAttempts(
-        stageId: Int, interruptThread: Boolean, reason: String): Unit = {
+          stageId: Int, interruptThread: Boolean, reason: String): Unit = {
         throw new UnsupportedOperationException
       }
       override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
@@ -5035,6 +5049,10 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   }
 }
 
+class DAGSchedulerAbortStageOffSuite extends DAGSchedulerSuite {
+  override def conf: SparkConf = super.conf.set(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS, false)
+}
+
 object DAGSchedulerSuite {
   val mergerLocs = ArrayBuffer[BlockManagerId]()
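`DAGSchedulerAbortStageOffSuite` re-runs the entire `DAGSchedulerSuite` with the legacy flag off; together with the `TempLocalSparkContext` fix above, a one-line `conf` override is all it takes. The same pattern would work for any other flag, e.g. (hypothetical sibling suite, real config key from the imports above):

```scala
// Hypothetical example of the same pattern for another internal flag.
class DAGSchedulerCacheVisibilityOnSuite extends DAGSchedulerSuite {
  override def conf: SparkConf =
    super.conf.set(config.RDD_CACHE_VISIBILITY_TRACKING_ENABLED, true)
}
```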
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index a5c9331e41e1..30973cc963fb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -85,7 +85,6 @@ private class DummyTaskScheduler extends TaskScheduler {
   override def start(): Unit = {}
   override def stop(exitCode: Int): Unit = {}
   override def submitTasks(taskSet: TaskSet): Unit = {}
-  override def cancelTasks(stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
   override def killTaskAttempt(
     taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
   override def killAllTaskAttempts(
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 72d0354c5577..7d354f27ff73 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -1651,7 +1651,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext
     assert(1 === taskDescriptions.length)
   }
 
-  test("cancelTasks shall kill all the running tasks and fail the stage") {
+  test("killAllTaskAttempts shall kill all the running tasks") {
     val taskScheduler = setupScheduler()
 
     taskScheduler.initialize(new FakeSchedulerBackend {
@@ -1677,43 +1677,12 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext
     val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
     assert(2 === tsm.runningTasks)
 
-    taskScheduler.cancelTasks(0, false, "test message")
+    taskScheduler.killAllTaskAttempts(0, false, "test message")
     assert(0 === tsm.runningTasks)
     assert(tsm.isZombie)
     assert(taskScheduler.taskSetManagerForAttempt(0, 0).isEmpty)
   }
 
-  test("killAllTaskAttempts shall kill all the running tasks and not fail the stage") {
-    val taskScheduler = setupScheduler()
-
-    taskScheduler.initialize(new FakeSchedulerBackend {
-      override def killTask(
-          taskId: Long,
-          executorId: String,
-          interruptThread: Boolean,
-          reason: String): Unit = {
-        // Since we only submit one stage attempt, the following call is sufficient to mark the
-        // task as killed.
-        taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId)
-      }
-    })
-
-    val attempt1 = FakeTask.createTaskSet(10)
-    taskScheduler.submitTasks(attempt1)
-
-    val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
-      new WorkerOffer("executor1", "host1", 1))
-    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
-    assert(2 === taskDescriptions.length)
-    val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
-    assert(2 === tsm.runningTasks)
-
-    taskScheduler.killAllTaskAttempts(0, false, "test")
-    assert(0 === tsm.runningTasks)
-    assert(!tsm.isZombie)
-    assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined)
-  }
-
   test("mark taskset for a barrier stage as zombie in case a task fails") {
     val taskScheduler = setupScheduler()