[SPARK-46052][CORE] Remove function TaskScheduler.killAllTaskAttempts
### What changes were proposed in this pull request?

This PR removes the interface `TaskScheduler.cancelTasks` and its implementations, and replaces its usages with `TaskScheduler.killAllTaskAttempts`. This PR also removes "abort stage" from the task-killing path and moves it to after the call of `TaskScheduler.killAllTaskAttempts` in `DAGScheduler`, guarded by the control flag `spark.scheduler.stage.legacyAbortAfterKillTasks` (`true` by default to keep the same behaviour for now), because "abort stage" is not necessary while cancelling tasks; see the comment at #43954 (comment).

Besides, this PR fixes a bug where new tasks could potentially be launched after all the tasks in a stage attempt had been killed. The fix marks the task set manager as a zombie (i.e., calls `suspend()`) after the kill.
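
A rough, self-contained sketch of the zombie mechanism (toy names, not the actual Spark classes): once a task set manager is marked as a zombie it stops handing out new tasks, so killing the running tasks and then calling `suspend()` ensures the stage attempt launches nothing new.

```scala
import scala.collection.mutable

// Illustrative sketch only; the real logic lives in TaskSetManager/TaskSchedulerImpl.
class ToyTaskSetManager(val stageId: Int) {
  @volatile var isZombie = false
  private val pendingTasks = mutable.Queue(1, 2, 3)

  // Mirrors the intent of TaskSetManager.suspend(): stop launching new tasks
  // without failing the stage.
  def suspend(): Unit = { isZombie = true }

  // A zombie task set never hands out new tasks on a resource offer.
  def resourceOffer(): Option[Int] =
    if (isZombie || pendingTasks.isEmpty) None else Some(pendingTasks.dequeue())
}

object ZombieSketch extends App {
  val tsm = new ToyTaskSetManager(stageId = 0)
  println(tsm.resourceOffer()) // Some(1): tasks are still being launched
  tsm.suspend()                // the kill path marks the attempt as a zombie
  println(tsm.resourceOffer()) // None: no new tasks after the kill
}
```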

### Why are the changes needed?

Spark has two functions to kill all the tasks in a stage:
* `cancelTasks`: kills all the running tasks in all the stage attempts and also aborts all the stage attempts.
* `killAllTaskAttempts`: only kills all the running tasks in all the stage attempts, without aborting the attempts.

However, there is no use case in Spark where a stage launches new tasks after all of its tasks have been killed, and aborting a stage is not required as part of killing its tasks. So I think we can remove `cancelTasks` and keep `killAllTaskAttempts` as the single entry point, with the stage abort performed separately behind the legacy flag.
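
A minimal, self-contained sketch of the resulting control flow, with illustrative stand-ins for the real `DAGScheduler`/`TaskSchedulerImpl` members (the flag below mirrors the new config's default):

```scala
// Sketch of the consolidated cancellation path: killing tasks and aborting the
// stage are now two separate steps, with the abort guarded by the legacy flag.
object CancellationFlowSketch extends App {
  trait TaskScheduler {
    def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit
  }

  val taskScheduler: TaskScheduler = new TaskScheduler {
    def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit =
      println(s"killing all task attempts of stage $stageId: $reason")
  }

  // Mirrors spark.scheduler.stage.legacyAbortAfterKillTasks (default: true).
  val legacyAbortStageAfterKillTasks = true

  def cancelStage(stageId: Int, reason: String): Unit = {
    // Always kill the running tasks; in the real code this also marks the attempts as zombies.
    taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason)
    // Fail/abort the stage only when the legacy behaviour is requested.
    if (legacyAbortStageAfterKillTasks) println(s"aborting stage $stageId: $reason")
  }

  cancelStage(0, "job cancelled")
}
```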

### Does this PR introduce _any_ user-facing change?
No. `TaskScheduler` is internal.

### How was this patch tested?

Pass existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43954 from Ngone51/remove_killAllTaskAttempts.

Lead-authored-by: Yi Wu <yi.wu@databricks.com>
Co-authored-by: wuyi <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Ngone51 authored and cloud-fan committed Jan 12, 2024
1 parent 9e68a4c commit 96f34bb
Showing 10 changed files with 94 additions and 87 deletions.
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2621,6 +2621,16 @@ package object config {
.toSequence
.createWithDefault("org.apache.spark.sql.connect.client" :: Nil)

private[spark] val LEGACY_ABORT_STAGE_AFTER_KILL_TASKS =
ConfigBuilder("spark.scheduler.stage.legacyAbortAfterKillTasks")
.doc("Whether to abort a stage after TaskScheduler.killAllTaskAttempts(). This is " +
"used to restore the original behavior in case there are any regressions after " +
"abort stage is removed")
.version("4.0.0")
.internal()
.booleanConf
.createWithDefault(true)

private[spark] val DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION =
ConfigBuilder("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled")
.internal()
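
If the legacy abort-after-kill behaviour ever needs to be turned off, the new flag can be set like any other Spark configuration. A hedged usage sketch (only the key comes from this diff; the rest is ordinary `SparkConf` usage):

```scala
import org.apache.spark.SparkConf

object LegacyAbortFlagExample {
  // Opt out of aborting a stage after its tasks are killed.
  // The default (true) preserves the pre-existing behaviour.
  val conf: SparkConf = new SparkConf()
    .setAppName("example")
    .set("spark.scheduler.stage.legacyAbortAfterKillTasks", "false")
}
```
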
20 changes: 16 additions & 4 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,7 +38,7 @@ import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.internal.config.RDD_CACHE_VISIBILITY_TRACKING_ENABLED
import org.apache.spark.internal.config.{LEGACY_ABORT_STAGE_AFTER_KILL_TASKS, RDD_CACHE_VISIBILITY_TRACKING_ENABLED}
import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
import org.apache.spark.network.shuffle.{BlockStoreClient, MergeFinalizerListener}
import org.apache.spark.network.shuffle.protocol.MergeStatuses
@@ -273,7 +273,13 @@ private[spark] class DAGScheduler(
private val messageScheduler =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message")

private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
private[spark] var eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
// Used for testing only. Some tests use the same thread as the event poster to
// process the events, in order to ensure deterministic behavior during the test.
private[spark] def setEventProcessLoop(loop: DAGSchedulerEventProcessLoop): Unit = {
eventProcessLoop = loop
}

taskScheduler.setDAGScheduler(this)

private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf, isDriver = true)
@@ -321,6 +327,9 @@ private[spark] class DAGScheduler(
private val trackingCacheVisibility: Boolean =
sc.getConf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED)

/** Whether to abort a stage after canceling all of its tasks. */
private val legacyAbortStageAfterKillTasks = sc.getConf.get(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS)

/**
* Called by the TaskSetManager to report task's starting.
*/
@@ -2860,8 +2869,11 @@ private[spark] class DAGScheduler(
// This stage is only used by the job, so finish the stage if it is running.
val stage = stageIdToStage(stageId)
if (runningStages.contains(stage)) {
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptTaskThread(job), reason)
try { // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask
taskScheduler.killAllTaskAttempts(stageId, shouldInterruptTaskThread(job), reason)
if (legacyAbortStageAfterKillTasks) {
stageFailed(stageId, reason)
}
markStageAsFinished(stage, Some(reason))
} catch {
case e: UnsupportedOperationException =>
@@ -54,10 +54,6 @@ private[spark] trait TaskScheduler {
// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit

// Kill all the tasks in a stage and fail the stage and all the jobs that depend on the stage.
// Throw UnsupportedOperationException if the backend doesn't support kill tasks.
def cancelTasks(stageId: Int, interruptThread: Boolean, reason: String): Unit

/**
* Kills a task attempt.
* Throw UnsupportedOperationException if the backend doesn't support kill a task.
*/
def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean

// Kill all the running task attempts in a stage.
// Kill all the tasks in all the stage attempts of the same stage Id
// Throw UnsupportedOperationException if the backend doesn't support kill tasks.
def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit

@@ -296,18 +296,31 @@ private[spark] class TaskSchedulerImpl(
new TaskSetManager(this, taskSet, maxTaskFailures, healthTrackerOpt, clock)
}

override def cancelTasks(
// Kill all the tasks in all the stage attempts of the same stage Id. Note stage attempts won't
// be aborted but will be marked as zombie. The stage attempt will be finished and cleaned up
// once all the tasks have finished. The stage attempt could be aborted after the call of
// `killAllTaskAttempts` if required.
override def killAllTaskAttempts(
stageId: Int,
interruptThread: Boolean,
reason: String): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
// Kill all running tasks for the stage.
killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled: " + reason)
// Cancel all attempts for the stage.
logInfo(s"Killing all running tasks in stage $stageId: $reason")
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
attempts.foreach { case (_, tsm) =>
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task.
// 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply continue.
tsm.runningTasksSet.foreach { tid =>
taskIdToExecutorId.get(tid).foreach { execId =>
backend.killTask(tid, execId, interruptThread, s"Stage cancelled: $reason")
}
}
tsm.suspend()
logInfo("Stage %s.%s was cancelled".format(stageId, tsm.taskSet.stageAttemptId))
}
}
}
}
}

override def killAllTaskAttempts(
stageId: Int,
interruptThread: Boolean,
reason: String): Unit = synchronized {
logInfo(s"Killing all running tasks in stage $stageId: $reason")
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
attempts.foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task.
// 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply continue.
tsm.runningTasksSet.foreach { tid =>
taskIdToExecutorId.get(tid).foreach { execId =>
backend.killTask(tid, execId, interruptThread, reason)
}
}
}
}
}

override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
}
@@ -1059,6 +1059,17 @@ private[spark] class TaskSetManager(
maybeFinishTaskSet()
}

// Suspends this TSM to avoid launching new tasks.
//
// Unlike `abort()`, this function intentionally does not notify the DAGScheduler, to avoid
// redundant operations. Callers of this function should therefore assume the DAGScheduler
// already knows about this TSM failure. For example, this function can be called from
// `TaskScheduler.killAllTaskAttempts` by the DAGScheduler.
def suspend(): Unit = {
isZombie = true
maybeFinishTaskSet()
}

/** If the given task ID is not in the set of running tasks, adds it.
*
* Used to keep track of the number of running tasks, for enforcing scheduling policies.
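
For comparison, a rough sketch (toy class, not the real `TaskSetManager`) of how `suspend()` differs from the pre-existing `abort()`, which additionally reports the failure to the DAGScheduler:

```scala
// Illustrative only; in Spark, abort() reports the failure to the DAGScheduler
// (which fails the stage and dependent jobs) before marking the manager as a zombie,
// while suspend() only marks it as a zombie.
class ToyManager(reportFailure: String => Unit) {
  var isZombie = false

  def abort(message: String): Unit = {
    reportFailure(message) // the real code notifies the DAGScheduler here
    isZombie = true
  }

  def suspend(): Unit = {
    isZombie = true        // no failure event; the caller already knows why
  }
}
```
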
@@ -547,7 +547,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft

sc.addSparkListener(new SparkListener {
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
// release taskCancelledSemaphore when cancelTasks event has been posted
// release taskCancelledSemaphore when killAllTaskAttempts event has been posted
if (stageCompleted.stageInfo.stageId == 1) {
taskCancelledSemaphore.release(numElements)
}
@@ -51,7 +51,7 @@ trait TempLocalSparkContext extends BeforeAndAfterEach
*/
def sc: SparkContext = {
if (_sc == null) {
_sc = new SparkContext(_conf)
_sc = new SparkContext(conf)
}
_sc
}
@@ -22,7 +22,7 @@ import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture, TimeUnit}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}

import scala.annotation.meta.param
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
import scala.jdk.CollectionConverters._
import scala.language.reflectiveCalls
import scala.util.control.NonFatal
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.executor.ExecutorMetrics
import org.apache.spark.internal.config
import org.apache.spark.internal.config.Tests
import org.apache.spark.internal.config.{LEGACY_ABORT_STAGE_AFTER_KILL_TASKS, Tests}
import org.apache.spark.network.shuffle.ExternalBlockStoreClient
import org.apache.spark.rdd.{DeterministicLevel, RDD}
import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, ResourceProfileBuilder, TaskResourceProfile, TaskResourceRequests}
class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
extends DAGSchedulerEventProcessLoop(dagScheduler) {

dagScheduler.setEventProcessLoop(this)

private var isProcessing = false
private val eventQueue = new ListBuffer[DAGSchedulerEvent]()


override def post(event: DAGSchedulerEvent): Unit = {
try {
// Forward event to `onReceive` directly to avoid processing event asynchronously.
onReceive(event)
} catch {
case NonFatal(e) => onError(e)
if (isProcessing) {
// `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially. So we should
// buffer events for subsequent processing instead of processing them recursively.
eventQueue += event
} else {
try {
isProcessing = true
// Forward event to `onReceive` directly to avoid processing event asynchronously.
onReceive(event)
} catch {
case NonFatal(e) => onError(e)
} finally {
isProcessing = false
}
if (eventQueue.nonEmpty) {
post(eventQueue.remove(0))
}
}
}

@@ -168,7 +186,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
/** Set of TaskSets the DAGScheduler has requested executed. */
val taskSets = scala.collection.mutable.Buffer[TaskSet]()

/** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
/** Stages for which the DAGScheduler has called TaskScheduler.killAllTaskAttempts(). */
val cancelledStages = new HashSet[Int]()

val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
taskSets += taskSet
}
override def cancelTasks(stageId: Int, interruptThread: Boolean, reason: String): Unit = {
cancelledStages += stageId
}
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
override def killAllTaskAttempts(
stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
stageId: Int, interruptThread: Boolean, reason: String): Unit = {
cancelledStages += stageId
}
override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
val tasks = ts.tasks.filter(_.partitionId == partitionId)
@@ -867,15 +884,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
override def submitTasks(taskSet: TaskSet): Unit = {
taskSets += taskSet
}
override def cancelTasks(stageId: Int, interruptThread: Boolean, reason: String): Unit = {
throw new UnsupportedOperationException
}
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
throw new UnsupportedOperationException
}
override def killAllTaskAttempts(
stageId: Int, interruptThread: Boolean, reason: String): Unit = {
stageId: Int, interruptThread: Boolean, reason: String): Unit = {
throw new UnsupportedOperationException
}
override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
@@ -5035,6 +5049,10 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
}
}

class DAGSchedulerAbortStageOffSuite extends DAGSchedulerSuite {
override def conf: SparkConf = super.conf.set(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS, false)
}

object DAGSchedulerSuite {
val mergerLocs = ArrayBuffer[BlockManagerId]()

@@ -85,7 +85,6 @@ private class DummyTaskScheduler extends TaskScheduler {
override def start(): Unit = {}
override def stop(exitCode: Int): Unit = {}
override def submitTasks(taskSet: TaskSet): Unit = {}
override def cancelTasks(stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
override def killAllTaskAttempts(
@@ -1651,7 +1651,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext
assert(1 === taskDescriptions.length)
}

test("cancelTasks shall kill all the running tasks and fail the stage") {
test("killAllTaskAttempts shall kill all the running tasks") {
val taskScheduler = setupScheduler()

taskScheduler.initialize(new FakeSchedulerBackend {
val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
assert(2 === tsm.runningTasks)

taskScheduler.cancelTasks(0, false, "test message")
taskScheduler.killAllTaskAttempts(0, false, "test message")
assert(0 === tsm.runningTasks)
assert(tsm.isZombie)
assert(taskScheduler.taskSetManagerForAttempt(0, 0).isEmpty)
}

test("killAllTaskAttempts shall kill all the running tasks and not fail the stage") {
val taskScheduler = setupScheduler()

taskScheduler.initialize(new FakeSchedulerBackend {
override def killTask(
taskId: Long,
executorId: String,
interruptThread: Boolean,
reason: String): Unit = {
// Since we only submit one stage attempt, the following call is sufficient to mark the
// task as killed.
taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId)
}
})

val attempt1 = FakeTask.createTaskSet(10)
taskScheduler.submitTasks(attempt1)

val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
new WorkerOffer("executor1", "host1", 1))
val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
assert(2 === taskDescriptions.length)
val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
assert(2 === tsm.runningTasks)

taskScheduler.killAllTaskAttempts(0, false, "test")
assert(0 === tsm.runningTasks)
assert(!tsm.isZombie)
assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined)
}

test("mark taskset for a barrier stage as zombie in case a task fails") {
val taskScheduler = setupScheduler()

