Use TaskAttemptIds to track shuffle memory
JoshRosen committed Jul 28, 2015
1 parent c9e8e54 commit 2e1e0f8
Showing 5 changed files with 50 additions and 46 deletions.
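The heart of the change: the shuffle memory bookkeeping that used to be keyed by the current thread's ID is now keyed by the task attempt ID taken from TaskContext. A minimal sketch of the two keying schemes (TaskContext is Spark's API; the helper names are illustrative, not from this commit):

import org.apache.spark.TaskContext

// Before this commit: shuffle memory was attributed to whichever JVM thread made the call.
def keyByThread(): Long = Thread.currentThread().getId

// After this commit: memory is attributed to the running task attempt, with -1L as a
// sentinel when no task is active (e.g. on the driver), as in the diff below.
def keyByTaskAttempt(): Long =
  Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)

With the second scheme, memory acquired and released from different threads of the same task ends up under a single key.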
@@ -265,7 +265,7 @@ private[spark] class PythonRDD(
}
} finally {
// Release memory used by this thread for shuffles
- env.shuffleMemoryManager.releaseMemoryForThisThread()
+ env.shuffleMemoryManager.releaseMemoryForThisTask()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
}
@@ -314,7 +314,7 @@ private[spark] class Executor(

} finally {
// Release memory used by this thread for shuffles
- env.shuffleMemoryManager.releaseMemoryForThisThread()
+ env.shuffleMemoryManager.releaseMemoryForThisTask()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
runningTasks.remove(taskId)
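Both finally blocks above now release shuffle memory for the task rather than for the thread. A hedged sketch of the cleanup pattern they share (these classes are private[spark], so this only compiles inside Spark itself; the wrapper is illustrative, only the two release calls come from the diff):

import org.apache.spark.SparkEnv

// Illustrative task-runner skeleton around the per-task cleanup calls shown above.
def withTaskCleanup(env: SparkEnv)(body: => Unit): Unit = {
  try {
    body
  } finally {
    // Per-task cleanup, keyed by the task attempt rather than by the calling thread.
    env.shuffleMemoryManager.releaseMemoryForThisTask()
    env.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
  }
}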
@@ -19,95 +19,99 @@ package org.apache.spark.shuffle

import scala.collection.mutable

- import org.apache.spark.{Logging, SparkException, SparkConf}
+ import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}

/**
- * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
+ * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
* collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
* from this pool and release it as it spills data out. When a task ends, all its memory will be
* released by the Executor.
*
- * This class tries to ensure that each thread gets a reasonable share of memory, instead of some
- * thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
- * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
+ * This class tries to ensure that each task gets a reasonable share of memory, instead of some
+ * task ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
* before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
- * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
+ * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
* this set changes. This is all done by synchronizing access on "this" to mutate state and using
* wait() and notifyAll() to signal changes.
*/
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
- private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes
+ private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes

def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))

+ private def currentTaskAttemptId(): Long = {
+ Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+ }

/**
- * Try to acquire up to numBytes memory for the current thread, and return the number of bytes
+ * Try to acquire up to numBytes memory for the current task, and return the number of bytes
* obtained, or 0 if none can be allocated. This call may block until there is enough free memory
- * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
- * total memory pool (where N is the # of active threads) before it is forced to spill. This can
- * happen if the number of threads increases but an older thread had a lot of memory already.
+ * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
+ * total memory pool (where N is the # of active tasks) before it is forced to spill. This can
+ * happen if the number of tasks increases but an older task had a lot of memory already.
*/
def tryToAcquire(numBytes: Long): Long = synchronized {
- val threadId = Thread.currentThread().getId
+ val taskAttemptId = currentTaskAttemptId()
assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)

- // Add this thread to the threadMemory map just so we can keep an accurate count of the number
- // of active threads, to let other threads ramp down their memory in calls to tryToAcquire
- if (!threadMemory.contains(threadId)) {
- threadMemory(threadId) = 0L
- notifyAll() // Will later cause waiting threads to wake up and check numThreads again
+ // Add this task to the taskMemory map just so we can keep an accurate count of the number
+ // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
+ if (!taskMemory.contains(taskAttemptId)) {
+ taskMemory(taskAttemptId) = 0L
+ notifyAll() // Will later cause waiting tasks to wake up and check numTasks again
}

// Keep looping until we're either sure that we don't want to grant this request (because this
- // thread would have more than 1 / numActiveThreads of the memory) or we have enough free
- // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)).
+ // task would have more than 1 / numActiveTasks of the memory) or we have enough free
+ // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
while (true) {
- val numActiveThreads = threadMemory.keys.size
- val curMem = threadMemory(threadId)
- val freeMemory = maxMemory - threadMemory.values.sum
+ val numActiveTasks = taskMemory.keys.size
+ val curMem = taskMemory(taskAttemptId)
+ val freeMemory = maxMemory - taskMemory.values.sum

- // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
+ // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
// don't let it be negative
- val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem))
+ val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))

- if (curMem < maxMemory / (2 * numActiveThreads)) {
- // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
- // if we can't give it this much now, wait for other threads to free up memory
- // (this happens if older threads allocated lots of memory before N grew)
- if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
+ if (curMem < maxMemory / (2 * numActiveTasks)) {
+ // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
+ // if we can't give it this much now, wait for other tasks to free up memory
+ // (this happens if older tasks allocated lots of memory before N grew)
+ if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
val toGrant = math.min(maxToGrant, freeMemory)
- threadMemory(threadId) += toGrant
+ taskMemory(taskAttemptId) += toGrant
return toGrant
} else {
logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
logInfo(s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
}
} else {
// Only give it as much memory as is free, which might be none if it reached 1 / numThreads
val toGrant = math.min(maxToGrant, freeMemory)
- threadMemory(threadId) += toGrant
+ taskMemory(taskAttemptId) += toGrant
return toGrant
}
}
0L // Never reached
}

- /** Release numBytes bytes for the current thread. */
+ /** Release numBytes bytes for the current task. */
def release(numBytes: Long): Unit = synchronized {
- val threadId = Thread.currentThread().getId
- val curMem = threadMemory.getOrElse(threadId, 0L)
+ val taskAttemptId = currentTaskAttemptId()
+ val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
if (curMem < numBytes) {
throw new SparkException(
s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}")
}
- threadMemory(threadId) -= numBytes
+ taskMemory(taskAttemptId) -= numBytes
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}

- /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */
- def releaseMemoryForThisThread(): Unit = synchronized {
- val threadId = Thread.currentThread().getId
- threadMemory.remove(threadId)
+ /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
+ def releaseMemoryForThisTask(): Unit = synchronized {
+ val taskAttemptId = currentTaskAttemptId()
+ taskMemory.remove(taskAttemptId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
}
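The Scaladoc above describes the sharing policy that the rewrite keeps intact: with N active tasks, each task is guaranteed at least 1 / 2N of the pool before it can be forced to spill, and is capped at 1 / N. A self-contained sketch of that policy, independent of Spark (class and method names are mine, and the task id is passed explicitly instead of being read from TaskContext):

import scala.collection.mutable

class FairShufflePool(maxMemory: Long) {
  private val granted = new mutable.HashMap[Long, Long]() // taskAttemptId -> bytes granted

  def tryToAcquire(taskId: Long, numBytes: Long): Long = synchronized {
    require(numBytes > 0, s"invalid number of bytes requested: $numBytes")
    // Register the task first so other callers see the correct active-task count.
    if (!granted.contains(taskId)) {
      granted(taskId) = 0L
      notifyAll()
    }
    while (true) {
      val n = granted.size                      // number of active tasks
      val curMem = granted(taskId)
      val free = maxMemory - granted.values.sum
      // Never let a single task grow past 1 / n of the pool.
      val maxToGrant = math.min(numBytes, math.max(0, maxMemory / n - curMem))
      if (curMem < maxMemory / (2 * n)) {
        // Below the 1 / 2n floor: grant if enough is free, otherwise wait for a release.
        if (free >= math.min(maxToGrant, maxMemory / (2 * n) - curMem)) {
          val toGrant = math.min(maxToGrant, free)
          granted(taskId) += toGrant
          return toGrant
        } else {
          wait()
        }
      } else {
        // At or above the floor: take whatever is free right now, possibly nothing.
        val toGrant = math.min(maxToGrant, free)
        granted(taskId) += toGrant
        return toGrant
      }
    }
    0L // never reached
  }

  def releaseAll(taskId: Long): Unit = synchronized {
    granted.remove(taskId)
    notifyAll()
  }
}

For example, with maxMemory = 1000 and two active tasks, each task can always ramp up to 250 bytes (1 / 2N) before it may be forced to wait or spill, and neither can ever hold more than 500 bytes (1 / N).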
@@ -484,7 +484,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}

private def currentTaskAttemptId(): Long = {
- Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1)
+ Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
}

/**
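The only change in the MemoryStore hunk is the default value: -1 becomes -1L. A small standalone illustration (not Spark code) of why the Long literal is the cleaner spelling: with an Int default, getOrElse on an Option[Long] infers the least upper bound of Long and Int, i.e. AnyVal, unless an expected type forces widening.

val id: Option[Long] = None
val a = id.getOrElse(-1)   // inferred as AnyVal (lub of Long and Int)
val b = id.getOrElse(-1L)  // inferred as Long, so the -1L sentinel keeps its type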
@@ -50,7 +50,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
assert(manager.tryToAcquire(300L) === 300L)
assert(manager.tryToAcquire(300L) === 200L)

- manager.releaseMemoryForThisThread()
+ manager.releaseMemoryForThisTask()
assert(manager.tryToAcquire(1000L) === 1000L)
assert(manager.tryToAcquire(100L) === 0L)
}
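A hypothetical reconstruction of the numbers behind the assertions above (assuming, from the later tryToAcquire(1000L) === 1000L assertion, that the suite builds the manager with maxMemory = 1000L and that roughly 500 bytes were granted by assertions not shown here):

val maxMemory = 1000L
val numActiveTasks = 1                        // the whole test runs as a single task
val cap = maxMemory / numActiveTasks          // 1000: one task may use the entire pool
val alreadyGranted = 500L                     // assumed from the earlier, unshown assertions
val freeBeforeSecondCall = maxMemory - alreadyGranted - 300L // = 200, hence the 200-byte grant
// After releaseMemoryForThisTask() the pool is empty again, so tryToAcquire(1000L)
// grants the full 1000 and the following tryToAcquire(100L) gets 0.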
@@ -253,7 +253,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
// Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
// sure the other thread blocks for some time otherwise
Thread.sleep(300)
- manager.releaseMemoryForThisThread()
+ manager.releaseMemoryForThisTask()
}

val t2 = startThread("t2") {
