[SPARK-9419] ShuffleMemoryManager and MemoryStore should track memory on a per-task, not per-thread, basis

Spark's ShuffleMemoryManager and MemoryStore track memory on a per-thread basis, which causes problems in the handful of cases where we have tasks that use multiple threads. In PythonRDD, RRDD, ScriptTransformation, and PipedRDD we consume the input iterator in a separate thread in order to write it to an external process. As a result, these RDDs' input iterators are consumed in a different thread than the thread that created them, which can cause problems in our memory allocation tracking. For example, if allocations are performed in one thread but deallocations are performed in a separate thread, then memory may be leaked or we may get errors complaining that more memory was allocated than was freed.

I think that the right way to fix this is to change our accounting to be performed on a per-task instead of per-thread basis.  Note that the current per-thread tracking has caused problems in the past; SPARK-3731 (#2668) fixes a memory leak in PythonRDD that was caused by this issue (that fix is no longer necessary as of this patch).
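
To make the failure mode concrete, here is a minimal standalone sketch (hypothetical code, not taken from Spark) of per-thread accounting going wrong when a task spans two threads: bytes acquired on the writer thread are never released by the task thread's cleanup, because cleanup only removes the entry keyed by the thread it happens to run on.

    // Hypothetical sketch of per-thread accounting (not Spark code).
    import scala.collection.mutable

    object PerThreadAccounting {
      private val memoryByThread = new mutable.HashMap[Long, Long]()

      def acquire(bytes: Long): Unit = synchronized {
        val tid = Thread.currentThread().getId
        memoryByThread(tid) = memoryByThread.getOrElse(tid, 0L) + bytes
      }

      // Task teardown under the old scheme: only frees the current thread's entry.
      def releaseAllForCurrentThread(): Unit = synchronized {
        memoryByThread.remove(Thread.currentThread().getId)
      }

      def main(args: Array[String]): Unit = {
        val writer = new Thread(new Runnable {
          override def run(): Unit = acquire(1024L)  // allocation happens on the writer thread
        })
        writer.start()
        writer.join()
        releaseAllForCurrentThread()  // runs on the "task" thread: removes nothing
        // 1024 bytes remain accounted to the dead writer thread, i.e. leaked.
        println(synchronized { memoryByThread.values.sum })
      }
    }

Keying the same map by task attempt ID instead makes the acquire and the release resolve to the same entry regardless of which thread they run on.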

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7734 from JoshRosen/memory-tracking-fixes and squashes the following commits:

b4b1702 [Josh Rosen] Propagate TaskContext to writer threads.
57c9b4e [Josh Rosen] Merge remote-tracking branch 'origin/master' into memory-tracking-fixes
ed25d3b [Josh Rosen] Address minor PR review comments
44f6497 [Josh Rosen] Fix long line.
7b0f04b [Josh Rosen] Fix ShuffleMemoryManagerSuite
f57f3f2 [Josh Rosen] More thread -> task changes
fa78ee8 [Josh Rosen] Move Executor's cleanup into Task so that TaskContext is defined when cleanup is performed
5e2f01e [Josh Rosen] Fix capitalization
1b0083b [Josh Rosen] Roll back fix in PySpark, which is no longer necessary
2e1e0f8 [Josh Rosen] Use TaskAttemptIds to track shuffle memory
c9e8e54 [Josh Rosen] Use TaskAttemptIds to track unroll memory
JoshRosen authored and rxin committed Jul 29, 2015
1 parent 429b2f0 commit ea49705
Showing 9 changed files with 184 additions and 152 deletions.
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

@@ -207,6 +207,7 @@ private[spark] class PythonRDD(
 
     override def run(): Unit = Utils.logUncaughtExceptions {
       try {
+        TaskContext.setTaskContext(context)
         val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
         val dataOut = new DataOutputStream(stream)
         // Partition index
@@ -263,11 +264,6 @@ private[spark] class PythonRDD(
         if (!worker.isClosed) {
           Utils.tryLog(worker.shutdownOutput())
         }
-      } finally {
-        // Release memory used by this thread for shuffles
-        env.shuffleMemoryManager.releaseMemoryForThisThread()
-        // Release memory used by this thread for unrolling blocks
-        env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
       }
     }
   }
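
The same propagation pattern recurs in the RRDD and PipedRDD changes below: capture the TaskContext while still on the task's main thread, then install it as the first statement of the writer thread's run(), so that TaskContext.get() — and hence the task attempt ID used for memory accounting — resolves correctly on that thread. A condensed sketch of the pattern (TaskContext.setTaskContext is Spark-internal API, so this compiles only inside Spark's own packages; the thread name and body are illustrative):

    // Illustrative sketch of the TaskContext-propagation pattern.
    val context = TaskContext.get()  // captured on the task's main thread

    new Thread("writer for external process") {
      override def run(): Unit = {
        // Adopt the parent's TaskContext first, so any shuffle or unroll memory
        // acquired on this thread is charged to the same task attempt ID.
        TaskContext.setTaskContext(context)
        // ... write the partition's records to the external process ...
      }
    }.start()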
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -112,13 +112,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
       partition: Int): Unit = {
 
     val env = SparkEnv.get
+    val taskContext = TaskContext.get()
     val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
     val stream = new BufferedOutputStream(output, bufferSize)
 
     new Thread("writer for R") {
       override def run(): Unit = {
         try {
           SparkEnv.set(env)
+          TaskContext.setTaskContext(taskContext)
           val dataOut = new DataOutputStream(stream)
           dataOut.writeInt(partition)
 
4 changes: 0 additions & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -313,10 +313,6 @@ private[spark] class Executor(
         }
 
       } finally {
-        // Release memory used by this thread for shuffles
-        env.shuffleMemoryManager.releaseMemoryForThisThread()
-        // Release memory used by this thread for unrolling blocks
-        env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
         runningTasks.remove(taskId)
       }
     }
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -133,6 +133,7 @@ private[spark] class PipedRDD[T: ClassTag](
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + command) {
       override def run() {
+        TaskContext.setTaskContext(context)
         val out = new PrintWriter(proc.getOutputStream)
 
         // scalastyle:off println
15 changes: 13 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.HashMap
 
 import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{TaskContextImpl, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -86,7 +86,18 @@ private[spark] abstract class Task[T](
       (runTask(context), context.collectAccumulators())
     } finally {
       context.markTaskCompleted()
-      TaskContext.unset()
+      try {
+        Utils.tryLogNonFatalError {
+          // Release memory used by this thread for shuffles
+          SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
+        }
+        Utils.tryLogNonFatalError {
+          // Release memory used by this thread for unrolling blocks
+          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
+        }
+      } finally {
+        TaskContext.unset()
+      }
     }
   }
 
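
Two details in the block above are worth noting: each release is wrapped in Utils.tryLogNonFatalError so that a failure in the shuffle-memory release cannot skip the unroll-memory release, and the nested finally guarantees TaskContext.unset() runs regardless. A minimal sketch of what such a helper can look like (Spark's actual Utils.tryLogNonFatalError may differ in its details):

    import scala.util.control.NonFatal

    // Run `block`, logging and swallowing non-fatal exceptions so the caller's
    // remaining cleanup steps still execute; fatal errors (OutOfMemoryError etc.)
    // still propagate.
    def tryLogNonFatalError(block: => Unit): Unit = {
      try {
        block
      } catch {
        case NonFatal(t) =>
          // Spark routes this through its Logging trait; println is a stand-in here.
          println(s"Uncaught exception in thread ${Thread.currentThread().getName}: $t")
      }
    }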
core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala

@@ -19,95 +19,101 @@ package org.apache.spark.shuffle
 
 import scala.collection.mutable
 
-import org.apache.spark.{Logging, SparkException, SparkConf}
+import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}
 
 /**
- * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
+ * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
  * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
  * from this pool and release it as it spills data out. When a task ends, all its memory will be
  * released by the Executor.
 *
- * This class tries to ensure that each thread gets a reasonable share of memory, instead of some
- * thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
- * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
+ * This class tries to ensure that each task gets a reasonable share of memory, instead of some
+ * task ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
  * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
- * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
+ * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
  * this set changes. This is all done by synchronizing access on "this" to mutate state and using
  * wait() and notifyAll() to signal changes.
  */
 private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
-  private val threadMemory = new mutable.HashMap[Long, Long]()  // threadId -> memory bytes
+  private val taskMemory = new mutable.HashMap[Long, Long]()  // taskAttemptId -> memory bytes
 
   def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
 
+  private def currentTaskAttemptId(): Long = {
+    // In case this is called on the driver, return an invalid task attempt id.
+    Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+  }
+
   /**
-   * Try to acquire up to numBytes memory for the current thread, and return the number of bytes
+   * Try to acquire up to numBytes memory for the current task, and return the number of bytes
    * obtained, or 0 if none can be allocated. This call may block until there is enough free memory
-   * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
-   * total memory pool (where N is the # of active threads) before it is forced to spill. This can
-   * happen if the number of threads increases but an older thread had a lot of memory already.
+   * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
+   * total memory pool (where N is the # of active tasks) before it is forced to spill. This can
+   * happen if the number of tasks increases but an older task had a lot of memory already.
    */
   def tryToAcquire(numBytes: Long): Long = synchronized {
-    val threadId = Thread.currentThread().getId
+    val taskAttemptId = currentTaskAttemptId()
     assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
 
-    // Add this thread to the threadMemory map just so we can keep an accurate count of the number
-    // of active threads, to let other threads ramp down their memory in calls to tryToAcquire
-    if (!threadMemory.contains(threadId)) {
-      threadMemory(threadId) = 0L
-      notifyAll()  // Will later cause waiting threads to wake up and check numThreads again
+    // Add this task to the taskMemory map just so we can keep an accurate count of the number
+    // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
+    if (!taskMemory.contains(taskAttemptId)) {
+      taskMemory(taskAttemptId) = 0L
+      notifyAll()  // Will later cause waiting tasks to wake up and check numTasks again
     }
 
     // Keep looping until we're either sure that we don't want to grant this request (because this
-    // thread would have more than 1 / numActiveThreads of the memory) or we have enough free
-    // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)).
+    // task would have more than 1 / numActiveTasks of the memory) or we have enough free
+    // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
     while (true) {
-      val numActiveThreads = threadMemory.keys.size
-      val curMem = threadMemory(threadId)
-      val freeMemory = maxMemory - threadMemory.values.sum
+      val numActiveTasks = taskMemory.keys.size
+      val curMem = taskMemory(taskAttemptId)
+      val freeMemory = maxMemory - taskMemory.values.sum
 
-      // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
+      // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
       // don't let it be negative
-      val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem))
+      val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))
 
-      if (curMem < maxMemory / (2 * numActiveThreads)) {
-        // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
-        // if we can't give it this much now, wait for other threads to free up memory
-        // (this happens if older threads allocated lots of memory before N grew)
-        if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
+      if (curMem < maxMemory / (2 * numActiveTasks)) {
+        // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
+        // if we can't give it this much now, wait for other tasks to free up memory
+        // (this happens if older tasks allocated lots of memory before N grew)
+        if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
           val toGrant = math.min(maxToGrant, freeMemory)
-          threadMemory(threadId) += toGrant
+          taskMemory(taskAttemptId) += toGrant
           return toGrant
         } else {
-          logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
+          logInfo(
+            s"Task $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
           wait()
         }
       } else {
         // Only give it as much memory as is free, which might be none if it reached 1 / numThreads
         val toGrant = math.min(maxToGrant, freeMemory)
-        threadMemory(threadId) += toGrant
+        taskMemory(taskAttemptId) += toGrant
         return toGrant
       }
     }
     0L  // Never reached
   }
 
-  /** Release numBytes bytes for the current thread. */
+  /** Release numBytes bytes for the current task. */
   def release(numBytes: Long): Unit = synchronized {
-    val threadId = Thread.currentThread().getId
-    val curMem = threadMemory.getOrElse(threadId, 0L)
+    val taskAttemptId = currentTaskAttemptId()
+    val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
     if (curMem < numBytes) {
       throw new SparkException(
-        s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
+        s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}")
     }
-    threadMemory(threadId) -= numBytes
+    taskMemory(taskAttemptId) -= numBytes
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
 
-  /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */
-  def releaseMemoryForThisThread(): Unit = synchronized {
-    val threadId = Thread.currentThread().getId
-    threadMemory.remove(threadId)
+  /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
+  def releaseMemoryForThisTask(): Unit = synchronized {
+    val taskAttemptId = currentTaskAttemptId()
+    taskMemory.remove(taskAttemptId)
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
 }
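
To make the 1 / 2N and 1 / N bounds concrete: with maxMemory = 1000 bytes and four active tasks, tryToAcquire caps any single task at 1000 / 4 = 250 bytes, and a task holding less than 1000 / (2 * 4) = 125 bytes blocks in wait() rather than accept a smaller grant. A short sketch of that arithmetic, mirroring the expressions in the code above (the numbers are illustrative):

    // Fairness bounds from tryToAcquire, with illustrative numbers.
    val maxMemory = 1000L
    val numActiveTasks = 4
    val curMem = 0L        // memory already held by this task
    val requested = 600L   // bytes the task asks for

    // Upper bound: never let one task grow past 1 / numActiveTasks of the pool.
    val maxToGrant = math.min(requested, math.max(0L, maxMemory / numActiveTasks - curMem))
    assert(maxToGrant == 250L)  // 600 requested, capped at 1000 / 4

    // Lower bound: below 1 / (2 * numActiveTasks), the task waits for memory to
    // free up instead of taking a sliver, so it can always ramp up to 125 bytes.
    val minBeforeSpill = maxMemory / (2 * numActiveTasks)
    assert(minBeforeSpill == 125L)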