Skip to content

Commit

Permalink
[SPARK-18761][BRANCH-2.0] Introduce "task reaper" to oversee task kil…
Browse files Browse the repository at this point in the history
…ling in executors

Branch-2.0 backport of #16189; original description follows:

## What changes were proposed in this pull request?

Spark's current task cancellation / task killing mechanism is "best effort" because some tasks may not be interruptible or may not respond to their "killed" flags being set. If a significant fraction of a cluster's task slots are occupied by tasks that have been marked as killed but remain running then this can lead to a situation where new jobs and tasks are starved of resources that are being used by these zombie tasks.

This patch aims to address this problem by adding a "task reaper" mechanism to executors. At a high-level, task killing now launches a new thread which attempts to kill the task and then watches the task and periodically checks whether it has been killed. The TaskReaper will periodically re-attempt to call `TaskRunner.kill()` and will log warnings if the task keeps running. I modified TaskRunner to rename its thread at the start of the task, allowing TaskReaper to take a thread dump and filter it in order to log stacktraces from the exact task thread that we are waiting to finish. If the task has not stopped after a configurable timeout then the TaskReaper will throw an exception to trigger executor JVM death, thereby forcibly freeing any resources consumed by the zombie tasks.

This feature is flagged off by default and is controlled by four new configurations under the `spark.task.reaper.*` namespace. See the updated `configuration.md` doc for details.

## How was this patch tested?

Tested via a new test case in `JobCancellationSuite`, plus manual testing.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #16358 from JoshRosen/cancellation-branch-2.0.
  • Loading branch information
JoshRosen authored and yhuai committed Dec 20, 2016
1 parent 1f0c5fa commit 678d91c
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 14 deletions.
169 changes: 160 additions & 9 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
Expand Up @@ -84,6 +84,16 @@ private[spark] class Executor(
// Start worker thread pool
private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
private val executorSource = new ExecutorSource(threadPool, executorId)
// Pool used for threads that supervise task killing / cancellation
private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
// For tasks which are in the process of being killed, this map holds the most recently created
// TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't
// a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding
// the integrity of the map's internal state). The purpose of this map is to prevent the creation
// of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to
// track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise
// create. The map key is a task id.
private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]()

if (!isLocal) {
env.metricsSystem.registerSource(executorSource)
Expand All @@ -93,6 +103,9 @@ private[spark] class Executor(
// Whether to load classes in user jars before those in Spark jars
private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false)

// Whether to monitor killed / interrupted tasks
private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false)

// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
Expand Down Expand Up @@ -148,9 +161,27 @@ private[spark] class Executor(
}

def killTask(taskId: Long, interruptThread: Boolean): Unit = {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill(interruptThread)
val taskRunner = runningTasks.get(taskId)
if (taskRunner != null) {
if (taskReaperEnabled) {
val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized {
val shouldCreateReaper = taskReaperForTask.get(taskId) match {
case None => true
case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
}
if (shouldCreateReaper) {
val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
taskReaperForTask(taskId) = taskReaper
Some(taskReaper)
} else {
None
}
}
// Execute the TaskReaper from outside of the synchronized block.
maybeNewTaskReaper.foreach(taskReaperPool.execute)
} else {
taskRunner.kill(interruptThread = interruptThread)
}
}
}

Expand All @@ -161,12 +192,7 @@ private[spark] class Executor(
* @param interruptThread whether to interrupt the task thread
*/
def killAllTasks(interruptThread: Boolean) : Unit = {
// kill all the running tasks
for (taskRunner <- runningTasks.values().asScala) {
if (taskRunner != null) {
taskRunner.kill(interruptThread)
}
}
runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread))
}

def stop(): Unit = {
Expand All @@ -192,13 +218,21 @@ private[spark] class Executor(
serializedTask: ByteBuffer)
extends Runnable {

val threadName = s"Executor task launch worker for task $taskId"

/** Whether this task has been killed. */
@volatile private var killed = false

@volatile private var threadId: Long = -1

def getThreadId: Long = threadId

/** Whether this task has been finished. */
@GuardedBy("TaskRunner.this")
private var finished = false

def isFinished: Boolean = synchronized { finished }

/** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _

Expand Down Expand Up @@ -229,9 +263,15 @@ private[spark] class Executor(
// ClosedByInterruptException during execBackend.statusUpdate which causes
// Executor to crash
Thread.interrupted()
// Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
// is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
// is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
notifyAll()
}

override def run(): Unit = {
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
Expand Down Expand Up @@ -411,6 +451,117 @@ private[spark] class Executor(
}
}

/**
* Supervises the killing / cancellation of a task by sending the interrupted flag, optionally
* sending a Thread.interrupt(), and monitoring the task until it finishes.
*
* Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
* may not be interruptable or may not respond to their "killed" flags being set. If a significant
* fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
* remain running then this can lead to a situation where new jobs and tasks are starved of
* resources that are being used by these zombie tasks.
*
* The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie
* tasks. For backwards-compatibility / backportability this component is disabled by default
* and must be explicitly enabled by setting `spark.task.reaper.enabled=true`.
*
* A TaskReaper is created for a particular task when that task is killed / cancelled. Typically
* a task will have only one TaskReaper, but it's possible for a task to have up to two reapers
* in case kill is called twice with different values for the `interrupt` parameter.
*
* Once created, a TaskReaper will run until its supervised task has finished running. If the
* TaskReaper has not been configured to kill the JVM after a timeout (i.e. if
* `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely
* if the supervised task never exits.
*/
private class TaskReaper(
taskRunner: TaskRunner,
val interruptThread: Boolean)
extends Runnable {

private[this] val taskId: Long = taskRunner.taskId

private[this] val killPollingIntervalMs: Long =
conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s")

private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1")

private[this] val takeThreadDump: Boolean =
conf.getBoolean("spark.task.reaper.threadDump", true)

override def run(): Unit = {
val startTimeMs = System.currentTimeMillis()
def elapsedTimeMs = System.currentTimeMillis() - startTimeMs
def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs
try {
// Only attempt to kill the task once. If interruptThread = false then a second kill
// attempt would be a no-op and if interruptThread = true then it may not be safe or
// effective to interrupt multiple times:
taskRunner.kill(interruptThread = interruptThread)
// Monitor the killed task until it exits. The synchronization logic here is complicated
// because we don't want to synchronize on the taskRunner while possibly taking a thread
// dump, but we also need to be careful to avoid races between checking whether the task
// has finished and wait()ing for it to finish.
var finished: Boolean = false
while (!finished && !timeoutExceeded()) {
taskRunner.synchronized {
// We need to synchronize on the TaskRunner while checking whether the task has
// finished in order to avoid a race where the task is marked as finished right after
// we check and before we call wait().
if (taskRunner.isFinished) {
finished = true
} else {
taskRunner.wait(killPollingIntervalMs)
}
}
if (taskRunner.isFinished) {
finished = true
} else {
logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms")
if (takeThreadDump) {
try {
Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread =>
if (thread.threadName == taskRunner.threadName) {
logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}")
}
}
} catch {
case NonFatal(e) =>
logWarning("Exception thrown while obtaining thread dump: ", e)
}
}
}
}

if (!taskRunner.isFinished && timeoutExceeded()) {
if (isLocal) {
logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " +
"not killing JVM because we are running in local mode.")
} else {
// In non-local-mode, the exception thrown here will bubble up to the uncaught exception
// handler and cause the executor JVM to exit.
throw new SparkException(
s"Killing executor JVM because killed task $taskId could not be stopped within " +
s"$killTimeoutMs ms.")
}
}
} finally {
// Clean up entries in the taskReaperForTask map.
taskReaperForTask.synchronized {
taskReaperForTask.get(taskId).foreach { taskReaperInMap =>
if (taskReaperInMap eq this) {
taskReaperForTask.remove(taskId)
} else {
// This must have been a TaskReaper where interruptThread == false where a subsequent
// killTask() call for the same task had interruptThread == true and overwrote the
// map entry.
}
}
}
}
}
}

/**
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
Expand Down
26 changes: 21 additions & 5 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.util

import java.io._
import java.lang.management.ManagementFactory
import java.lang.management.{ManagementFactory, ThreadInfo}
import java.net._
import java.nio.ByteBuffer
import java.nio.channels.Channels
Expand Down Expand Up @@ -2112,13 +2112,29 @@ private[spark] object Utils extends Logging {
// We need to filter out null values here because dumpAllThreads() may return null array
// elements for threads that are dead / don't exist.
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName,
threadInfo.getThreadState, stackTrace)
threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace)
}

def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = {
if (threadId <= 0) {
None
} else {
// The Int.MaxValue here requests the entire untruncated stack trace of the thread:
val threadInfo =
Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue))
threadInfo.map(threadInfoToThreadStackTrace)
}
}

private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = {
val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
ThreadStackTrace(
threadId = threadInfo.getThreadId,
threadName = threadInfo.getThreadName,
threadState = threadInfo.getThreadState,
stackTrace = stackTrace)
}

/**
* Convert all spark properties set in the given SparkConf to a sequence of java options.
*/
Expand Down
77 changes: 77 additions & 0 deletions core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
Expand Up @@ -209,6 +209,83 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
assert(jobB.get() === 100)
}

test("task reaper kills JVM if killed tasks keep running for too long") {
val conf = new SparkConf()
.set("spark.task.reaper.enabled", "true")
.set("spark.task.reaper.killTimeout", "5s")
sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)

// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart) {
sem.release()
}
})

// jobA is the one to be cancelled.
val jobA = Future {
sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
sc.parallelize(1 to 10000, 2).map { i =>
while (true) { }
}.count()
}

// Block until both tasks of job A have started and cancel job A.
sem.acquire(2)
// Small delay to ensure tasks actually start executing the task body
Thread.sleep(1000)

sc.clearJobGroup()
val jobB = sc.parallelize(1 to 100, 2).countAsync()
sc.cancelJobGroup("jobA")
val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause
assert(e.getMessage contains "cancel")

// Once A is cancelled, job B should finish fairly quickly.
assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
}

test("task reaper will not kill JVM if spark.task.killTimeout == -1") {
val conf = new SparkConf()
.set("spark.task.reaper.enabled", "true")
.set("spark.task.reaper.killTimeout", "-1")
.set("spark.task.reaper.PollingInterval", "1s")
.set("spark.deploy.maxExecutorRetries", "1")
sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)

// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart) {
sem.release()
}
})

// jobA is the one to be cancelled.
val jobA = Future {
sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
sc.parallelize(1 to 2, 2).map { i =>
val startTime = System.currentTimeMillis()
while (System.currentTimeMillis() < startTime + 10000) { }
}.count()
}

// Block until both tasks of job A have started and cancel job A.
sem.acquire(2)
// Small delay to ensure tasks actually start executing the task body
Thread.sleep(1000)

sc.clearJobGroup()
val jobB = sc.parallelize(1 to 100, 2).countAsync()
sc.cancelJobGroup("jobA")
val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause
assert(e.getMessage contains "cancel")

// Once A is cancelled, job B should finish fairly quickly.
assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
}

test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// twoJobsSharingStageSemaphore:
Expand Down

0 comments on commit 678d91c

Please sign in to comment.