Split MemoryManager into ExecutorMemoryManager and TaskMemoryManager:
- Implement memory leak detection, with exception vs. logging controlled by
  a configuration option.
JoshRosen committed Apr 28, 2015
1 parent 6e4b192 commit 70a39e4
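
In short: each executor owns a single ExecutorMemoryManager wrapping the chosen allocator, and each task attempt gets its own TaskMemoryManager, whose cleanUpAllAllocatedMemory() frees any pages the task failed to release and returns the number of leaked bytes. A minimal sketch of that lifecycle, using only the constructors and methods that appear in this diff:

    import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

    // One executor-wide manager, shared by every task on the executor.
    val executorMemoryManager = new ExecutorMemoryManager(MemoryAllocator.HEAP)

    // One manager per task attempt; it tracks every page the task allocates.
    val taskMemoryManager = new TaskMemoryManager(executorMemoryManager)

    // ... the task allocates and frees pages through taskMemoryManager ...

    // On task exit, reclaim whatever the task did not free; a non-zero
    // return value is the size of the leak in bytes.
    val leakedBytes = taskMemoryManager.cleanUpAllAllocatedMemory()

Whether a leak then throws or is only logged is controlled by spark.unsafe.exceptionOnMemoryLeak, which the build changes below enable for tests.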
Showing 21 changed files with 259 additions and 60 deletions.
10 changes: 5 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -40,7 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._
-import org.apache.spark.unsafe.memory.{MemoryManager => UnsafeMemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
 import org.apache.spark.util.{RpcUtils, Utils}
 
 /**
@@ -70,7 +70,7 @@ class SparkEnv (
     val sparkFilesDir: String,
     val metricsSystem: MetricsSystem,
     val shuffleMemoryManager: ShuffleMemoryManager,
-    val unsafeMemoryManager: UnsafeMemoryManager,
+    val executorMemoryManager: ExecutorMemoryManager,
     val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {
 
@@ -384,13 +384,13 @@ object SparkEnv extends Logging {
       new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
     outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
 
-    val unsafeMemoryManager: UnsafeMemoryManager = {
+    val executorMemoryManager: ExecutorMemoryManager = {
       val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
         MemoryAllocator.UNSAFE
       } else {
         MemoryAllocator.HEAP
       }
-      new UnsafeMemoryManager(allocator)
+      new ExecutorMemoryManager(allocator)
     }
 
     val envInstance = new SparkEnv(
@@ -409,7 +409,7 @@
       sparkFilesDir,
       metricsSystem,
       shuffleMemoryManager,
-      unsafeMemoryManager,
+      executorMemoryManager,
       outputCommitCoordinator,
       conf)
 
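
The allocator choice above is the only place spark.unsafe.offHeap is consulted; it defaults to false, so managed memory stays on-heap unless an application opts in. A sketch of opting in (the key is taken from this diff; the rest is ordinary SparkConf usage):

    import org.apache.spark.SparkConf

    // Opt in to off-heap (sun.misc.Unsafe) allocation for managed memory.
    val conf = new SparkConf().set("spark.unsafe.offHeap", "true")
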
6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.TaskCompletionListener
 
 
@@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable {
   /** ::DeveloperApi:: */
   @DeveloperApi
   def taskMetrics(): TaskMetrics
+
+  /**
+   * Returns the manager for this task's managed memory.
+   */
+  private[spark] def taskMemoryManager(): TaskMemoryManager
 }
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
 
 import scala.collection.mutable.ArrayBuffer
@@ -27,6 +28,7 @@ private[spark] class TaskContextImpl(
     val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,
+    override val taskMemoryManager: TaskMemoryManager,
     val runningLocally: Boolean = false,
     val taskMetrics: TaskMetrics = TaskMetrics.empty)
   extends TaskContext
21 changes: 20 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util._
 
 /**
@@ -179,6 +180,7 @@ private[spark] class Executor(
     }
 
     override def run(): Unit = {
+      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
       val deserializeStartTime = System.currentTimeMillis()
       Thread.currentThread.setContextClassLoader(replClassLoader)
       val ser = env.closureSerializer.newInstance()
@@ -191,6 +193,7 @@
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
         updateDependencies(taskFiles, taskJars)
         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+        task.setTaskMemoryManager(taskMemoryManager)
 
         // If this task has been killed before we deserialized it, let's quit now. Otherwise,
         // continue executing the task.
@@ -207,7 +210,23 @@
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
-        val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+        var succeeded: Boolean = false
+        val value = try {
+          val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+          succeeded = true
+          value
+        } finally {
+          // Release managed memory used by this task
+          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+          if (succeeded && freedMemory > 0) {
+            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
+            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+              throw new SparkException(errMsg)
+            } else {
+              logError(errMsg)
+            }
+          }
+        }
         val taskFinish = System.currentTimeMillis()
 
         // If the task has been killed, let's fail it.
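
The guard on succeeded is deliberate: cleanup always runs, but a leak is only reported when the task completed normally, since a task that threw may never have reached its own cleanup code. A hypothetical helper, not part of this commit, that makes the shape of the check explicit (the same pattern recurs in DAGScheduler below):

    import org.apache.spark.{SparkConf, SparkException}
    import org.apache.spark.unsafe.memory.TaskMemoryManager

    // Hypothetical refactor, for illustration only: run a block under a
    // TaskMemoryManager and apply the leak policy on the way out.
    def withLeakCheck[T](tmm: TaskMemoryManager, conf: SparkConf)(body: => T): T = {
      var succeeded = false
      try {
        val result = body
        succeeded = true
        result
      } finally {
        val freedMemory = tmm.cleanUpAllAllocatedMemory()
        if (succeeded && freedMemory > 0) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes"
          if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
            throw new SparkException(errMsg) // strict mode, as the test builds enable
          } else {
            System.err.println(errMsg) // the real call sites use logError
          }
        }
      }
    }
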
22 changes: 20 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util._
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
 
@@ -643,15 +644,32 @@ class DAGScheduler(
     try {
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
-      val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
-        attemptNumber = 0, runningLocally = true)
+      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
+      val taskContext =
+        new TaskContextImpl(
+          job.finalStage.id,
+          job.partitions(0),
+          taskAttemptId = 0,
+          attemptNumber = 0,
+          taskMemoryManager = taskMemoryManager,
+          runningLocally = true)
       TaskContext.setTaskContext(taskContext)
+      var succeeded: Boolean = false
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
+        succeeded = true
         job.listener.taskSucceeded(0, result)
       } finally {
         taskContext.markTaskCompleted()
         TaskContext.unset()
+        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+        if (succeeded && freedMemory > 0) {
+          if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+            throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
+          } else {
+            logError(s"Managed memory leak detected; size = $freedMemory bytes")
+          }
+        }
       }
     } catch {
       case e: Exception =>
16 changes: 14 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap
 import org.apache.spark.{TaskContextImpl, TaskContext}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.ByteBufferInputStream
 import org.apache.spark.util.Utils
 
@@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
    * @return the result of the task
    */
   final def run(taskAttemptId: Long, attemptNumber: Int): T = {
-    context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
-      taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
+    context = new TaskContextImpl(
+      stageId = stageId,
+      partitionId = partitionId,
+      taskAttemptId = taskAttemptId,
+      attemptNumber = attemptNumber,
+      taskMemoryManager = taskMemoryManager,
+      runningLocally = false)
     TaskContext.setTaskContext(context)
     context.taskMetrics.setHostname(Utils.localHostName())
     taskThread = Thread.currentThread()
@@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
     }
   }
 
+  private var taskMemoryManager: TaskMemoryManager = _
+
+  def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
+    this.taskMemoryManager = taskMemoryManager
+  }
+
   def runTask(context: TaskContext): T
 
   def preferredLocations: Seq[TaskLocation] = Nil
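
Note the injection order: the executor calls setTaskMemoryManager() after deserializing a task and before run(), because run() copies the field into the TaskContextImpl it constructs. A toy illustration of the wiring (the NoopTask subclass is hypothetical, and since Task is private[spark] a snippet like this only compiles inside Spark itself):

    import org.apache.spark.TaskContext
    import org.apache.spark.scheduler.Task
    import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

    // Hypothetical no-op task; real subclasses are ShuffleMapTask and ResultTask.
    class NoopTask extends Task[Unit](stageId = 0, partitionId = 0) {
      override def runTask(context: TaskContext): Unit = ()
    }

    val task = new NoopTask
    // Must precede run(), which forwards the manager into the task's context.
    task.setTaskMemoryManager(
      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)))
    task.run(taskAttemptId = 0, attemptNumber = 0)
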
2 changes: 1 addition & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1009,7 +1009,7 @@ public void persist() {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
+    TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
     Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }
 
8 changes: 4 additions & 4 deletions core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
     // in blockManager.put is a losing battle. You have been warned.
     blockManager = sc.env.blockManager
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
     assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@
     val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
 
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(5, 6, 7))
   }
@@ -86,14 +86,14 @@
     // Local computation should not persist the resulting value, so don't expect a put().
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
 
-    val context = new TaskContextImpl(0, 0, 0, 0, true)
+    val context = new TaskContextImpl(0, 0, 0, 0, null, true)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(1, 2, 3, 4))
   }
 
   test("verify task metrics updated correctly") {
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
     assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
   }
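
These suites pass null for the new taskMemoryManager parameter because the code under test never allocates managed memory. A test that did allocate would construct a real manager, as the UnsafeFixedWidthAggregationMapSuite change further down does; a sketch (TaskContextImpl is private[spark], so this belongs in Spark's own tests):

    import org.apache.spark.TaskContextImpl
    import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

    // A real per-task manager backed by an on-heap allocator.
    val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
    val context = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager)
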
@@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
     }
     val hadoopPart1 = generateFakeHadoopPartition()
     val pipedRdd = new PipedRDD(nums, "printenv " + varName)
-    val tContext = new TaskContextImpl(0, 0, 0, 0)
+    val tContext = new TaskContextImpl(0, 0, 0, 0, null)
     val rddIter = pipedRdd.compute(hadoopPart1, tContext)
     val arr = rddIter.toArray
     assert(arr(0) == "/some/path")
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
   }
 
   test("all TaskCompletionListeners should be called even if some fail") {
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     val listener = mock(classOf[TaskCompletionListener])
     context.addTaskCompletionListener(_ => throw new Exception("blah"))
     context.addTaskCompletionListener(listener)
@@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
     )
 
     val iterator = new ShuffleBlockFetcherIterator(
-      new TaskContextImpl(0, 0, 0, 0),
+      new TaskContextImpl(0, 0, 0, 0, null),
       transfer,
       blockManager,
       blocksByAddress,
@@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
 
-    val taskContext = new TaskContextImpl(0, 0, 0, 0)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,
@@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
 
-    val taskContext = new TaskContextImpl(0, 0, 0, 0)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,
1 change: 1 addition & 0 deletions pom.xml
@@ -1206,6 +1206,7 @@
             <spark.ui.enabled>false</spark.ui.enabled>
             <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress>
             <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
+            <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak>
           </systemProperties>
           <failIfNoTests>false</failIfNoTests>
         </configuration>
1 change: 1 addition & 0 deletions project/SparkBuild.scala
@@ -496,6 +496,7 @@ object TestSettings {
     javaOptions in Test += "-Dspark.ui.enabled=false",
     javaOptions in Test += "-Dspark.ui.showConsoleProgress=false",
     javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
+    javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
     javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
     javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
       .map { case (k,v) => s"-D$k=$v" }.toSeq,
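
Both build definitions enable the new flag for tests only, so a leaked page fails the offending test instead of being silently logged; the shipped default remains log-and-continue. An application that wanted the strict behavior could set the same key itself, e.g.:

    import org.apache.spark.SparkConf

    // Strict mode: a detected managed memory leak throws SparkException.
    val conf = new SparkConf().set("spark.unsafe.exceptionOnMemoryLeak", "true")
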
@@ -26,7 +26,7 @@
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.MemoryManager;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
 
 /**
  * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -110,7 +110,7 @@ public UnsafeFixedWidthAggregationMap(
       Row emptyAggregationBuffer,
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
-      MemoryManager memoryManager,
+      TaskMemoryManager memoryManager,
       int initialCapacity,
       boolean enablePerfMetrics) {
     this.emptyAggregationBuffer =
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import scala.collection.JavaConverters._
 import scala.util.Random
 
-import org.apache.spark.unsafe.memory.{MemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
 import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
 
 import org.apache.spark.sql.types._
@@ -33,15 +33,15 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with Be
   private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
   private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0))
 
-  private var memoryManager: MemoryManager = null
+  private var memoryManager: TaskMemoryManager = null
 
   override def beforeEach(): Unit = {
-    memoryManager = new MemoryManager(MemoryAllocator.HEAP)
+    memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
   }
 
   override def afterEach(): Unit = {
     if (memoryManager != null) {
-      memoryManager.cleanUpAllPages()
+      memoryManager.cleanUpAllAllocatedMemory()
       memoryManager = null
     }
   }
@@ -17,14 +17,13 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.TaskContext
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.trees._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.MemoryAllocator
 
 case class AggregateEvaluation(
     schema: Seq[Attribute],
@@ -290,7 +289,7 @@ case class GeneratedAggregate(
         newAggregationBuffer(EmptyRow),
         aggregationBufferSchema,
         groupKeySchema,
-        SparkEnv.get.unsafeMemoryManager,
+        TaskContext.get.taskMemoryManager(),
         1024 * 16, // initial capacity
         false // disable tracking of performance metrics
       )
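
The switch from SparkEnv.get.unsafeMemoryManager to TaskContext.get.taskMemoryManager() means operator code now picks up the per-task manager from the running task's context. A sketch of the access pattern; it only works on a thread that is executing a task, and because the accessor is private[spark] only Spark-internal code can call it:

    import org.apache.spark.TaskContext

    // Valid only inside a running task; TaskContext.get returns null elsewhere.
    val memoryManager = TaskContext.get.taskMemoryManager()
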
