diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e3cba4547d98a..0c4d28f786edd 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -40,7 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._
-import org.apache.spark.unsafe.memory.{MemoryManager => UnsafeMemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
 import org.apache.spark.util.{RpcUtils, Utils}
 
 /**
@@ -70,7 +70,7 @@ class SparkEnv (
     val sparkFilesDir: String,
     val metricsSystem: MetricsSystem,
     val shuffleMemoryManager: ShuffleMemoryManager,
-    val unsafeMemoryManager: UnsafeMemoryManager,
+    val executorMemoryManager: ExecutorMemoryManager,
     val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {
 
@@ -384,13 +384,13 @@ object SparkEnv extends Logging {
       new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
     outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
 
-    val unsafeMemoryManager: UnsafeMemoryManager = {
+    val executorMemoryManager: ExecutorMemoryManager = {
       val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
         MemoryAllocator.UNSAFE
       } else {
         MemoryAllocator.HEAP
       }
-      new UnsafeMemoryManager(allocator)
+      new ExecutorMemoryManager(allocator)
     }
 
     val envInstance = new SparkEnv(
@@ -409,7 +409,7 @@ object SparkEnv extends Logging {
       sparkFilesDir,
       metricsSystem,
       shuffleMemoryManager,
-      unsafeMemoryManager,
+      executorMemoryManager,
       outputCommitCoordinator,
       conf)
 
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 7d7fe1a446313..d09e17dea0911 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.TaskCompletionListener
 
 
@@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable {
   /** ::DeveloperApi:: */
   @DeveloperApi
   def taskMetrics(): TaskMetrics
+
+  /**
+   * Returns the manager for this task's managed memory.
+   */
+  private[spark] def taskMemoryManager(): TaskMemoryManager
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 337c8e4ebebcd..b4d572cb52313 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -27,6 +28,7 @@ private[spark] class TaskContextImpl(
     val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,
+    override val taskMemoryManager: TaskMemoryManager,
     val runningLocally: Boolean = false,
     val taskMetrics: TaskMetrics = TaskMetrics.empty)
   extends TaskContext
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 327d155b38c22..c687ce9fab5bb 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util._
 
 /**
@@ -179,6 +180,7 @@ private[spark] class Executor(
     }
 
     override def run(): Unit = {
+      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
       val deserializeStartTime = System.currentTimeMillis()
       Thread.currentThread.setContextClassLoader(replClassLoader)
       val ser = env.closureSerializer.newInstance()
@@ -191,6 +193,7 @@ private[spark] class Executor(
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
         updateDependencies(taskFiles, taskJars)
         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+        task.setTaskMemoryManager(taskMemoryManager)
 
         // If this task has been killed before we deserialized it, let's quit now. Otherwise,
         // continue executing the task.
@@ -207,7 +210,23 @@ private[spark] class Executor(
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
-        val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+        var succeeded: Boolean = false
+        val value = try {
+          val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+          succeeded = true
+          value
+        } finally {
+          // Release managed memory used by this task
+          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+          if (succeeded && freedMemory > 0) {
+            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
+            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+              throw new SparkException(errMsg)
+            } else {
+              logError(errMsg)
+            }
+          }
+        }
         val taskFinish = System.currentTimeMillis()
 
         // If the task has been killed, let's fail it.
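A minimal, self-contained sketch of the per-task lifecycle the Executor change above establishes, using only the public classes this patch adds (the class name is invented for the example; the allocatePage call stands in for real operator work, and the println stands in for Spark's logging):

    import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.TaskMemoryManager;

    public class TaskMemoryLifecycleSketch {
      public static void main(String[] args) {
        // One ExecutorMemoryManager per executor (held by SparkEnv);
        // one TaskMemoryManager per task attempt.
        ExecutorMemoryManager executorMemoryManager =
            new ExecutorMemoryManager(MemoryAllocator.HEAP);
        TaskMemoryManager taskMemoryManager = new TaskMemoryManager(executorMemoryManager);
        boolean succeeded = false;
        try {
          // Operators allocate through the task-scoped manager ...
          taskMemoryManager.allocatePage(4096);  // deliberately never freed
          succeeded = true;
        } finally {
          // ... and anything they forgot to free is reclaimed here, mirroring
          // the finally block added to Executor.run() above.
          long freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
          if (succeeded && freedMemory > 0) {
            System.err.println("Managed memory leak detected; size = " + freedMemory + " bytes");
          }
        }
      }
    }

Running this reports a 4096-byte leak, since the page was never freed; with spark.unsafe.exceptionOnMemoryLeak=true the real Executor throws a SparkException instead of merely logging.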
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4a32f8936fb0e..f63e894568d71 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util._
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
 
@@ -643,15 +644,32 @@ class DAGScheduler(
     try {
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
-      val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
-        attemptNumber = 0, runningLocally = true)
+      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
+      val taskContext =
+        new TaskContextImpl(
+          job.finalStage.id,
+          job.partitions(0),
+          taskAttemptId = 0,
+          attemptNumber = 0,
+          taskMemoryManager = taskMemoryManager,
+          runningLocally = true)
       TaskContext.setTaskContext(taskContext)
+      var succeeded: Boolean = false
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
+        succeeded = true
         job.listener.taskSucceeded(0, result)
       } finally {
         taskContext.markTaskCompleted()
         TaskContext.unset()
+        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+        if (succeeded && freedMemory > 0) {
+          if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+            throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
+          } else {
+            logError(s"Managed memory leak detected; size = $freedMemory bytes")
+          }
+        }
       }
     } catch {
       case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 8b592867ee31d..c4187a0cfab69 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap
 import org.apache.spark.{TaskContextImpl, TaskContext}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.ByteBufferInputStream
 import org.apache.spark.util.Utils
 
@@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
    * @return the result of the task
    */
   final def run(taskAttemptId: Long, attemptNumber: Int): T = {
-    context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
-      taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
+    context = new TaskContextImpl(
+      stageId = stageId,
+      partitionId = partitionId,
+      taskAttemptId = taskAttemptId,
+      attemptNumber = attemptNumber,
+      taskMemoryManager = taskMemoryManager,
+      runningLocally = false)
     TaskContext.setTaskContext(context)
     context.taskMetrics.setHostname(Utils.localHostName())
     taskThread = Thread.currentThread()
@@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
     }
   }
 
+  private var taskMemoryManager: TaskMemoryManager = _
+
+  def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
+    this.taskMemoryManager = taskMemoryManager
+  }
+
   def runTask(context: TaskContext): T
 
   def preferredLocations: Seq[TaskLocation] = Nil
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 8a4f2a08fe701..34ac9361d46c6 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1009,7 +1009,7 @@ public void persist() {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
+    TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
     Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }
 
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 70529d9216591..668ddf9f5f0a9 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
     // in blockManager.put is a losing battle. You have been warned.
     blockManager = sc.env.blockManager
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
     assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
     val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
 
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(5, 6, 7))
   }
@@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
     // Local computation should not persist the resulting value, so don't expect a put().
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
 
-    val context = new TaskContextImpl(0, 0, 0, 0, true)
+    val context = new TaskContextImpl(0, 0, 0, 0, null, true)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(1, 2, 3, 4))
   }
 
   test("verify task metrics updated correctly") {
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
     assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
   }
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index aea76c1adcc09..85eb2a1d07ba4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
     }
     val hadoopPart1 = generateFakeHadoopPartition()
     val pipedRdd = new PipedRDD(nums, "printenv " + varName)
-    val tContext = new TaskContextImpl(0, 0, 0, 0)
+    val tContext = new TaskContextImpl(0, 0, 0, 0, null)
     val rddIter = pipedRdd.compute(hadoopPart1, tContext)
     val arr = rddIter.toArray
     assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 057e226916027..83ae8701243e5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
   }
 
   test("all TaskCompletionListeners should be called even if some fail") {
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     val listener = mock(classOf[TaskCompletionListener])
     context.addTaskCompletionListener(_ => throw new Exception("blah"))
     context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 37b593b2c5f79..2080c432d77db 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
     )
 
     val iterator = new ShuffleBlockFetcherIterator(
-      new TaskContextImpl(0, 0, 0, 0),
+      new TaskContextImpl(0, 0, 0, 0, null),
       transfer,
       blockManager,
       blocksByAddress,
@@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
 
-    val taskContext = new TaskContextImpl(0, 0, 0, 0)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,
@@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
 
-    val taskContext = new TaskContextImpl(0, 0, 0, 0)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,
diff --git a/pom.xml b/pom.xml
index 155670e745cf8..92275ad4400f6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1206,6 +1206,7 @@
               <spark.ui.enabled>false</spark.ui.enabled>
               <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress>
               <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
+              <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index e2ffff8be14a5..b7dbcd9bc562a 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -496,6 +496,7 @@ object TestSettings {
     javaOptions in Test += "-Dspark.ui.enabled=false",
     javaOptions in Test += "-Dspark.ui.showConsoleProgress=false",
     javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
+    javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
     javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
     javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
       .map { case (k,v) => s"-D$k=$v" }.toSeq,
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 0a4ab84f76cbe..299ff3728a6d9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -26,7 +26,7 @@
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.MemoryManager;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
 
 /**
  * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -110,7 +110,7 @@ public UnsafeFixedWidthAggregationMap(
       Row emptyAggregationBuffer,
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
-      MemoryManager memoryManager,
+      TaskMemoryManager memoryManager,
       int initialCapacity,
       boolean enablePerfMetrics) {
     this.emptyAggregationBuffer =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index e7ea1680ee481..7a19e511eb8b5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import scala.collection.JavaConverters._
 import scala.util.Random
 
-import org.apache.spark.unsafe.memory.{MemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
 import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
 
 import org.apache.spark.sql.types._
@@ -33,15 +33,15 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with Be
   private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
   private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0))
 
-  private var memoryManager: MemoryManager = null
+  private var memoryManager: TaskMemoryManager = null
 
   override def beforeEach(): Unit = {
-    memoryManager = new MemoryManager(MemoryAllocator.HEAP)
+    memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
   }
 
   override def afterEach(): Unit = {
     if (memoryManager != null) {
-      memoryManager.cleanUpAllPages()
+      memoryManager.cleanUpAllAllocatedMemory()
       memoryManager = null
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 226e41f9b09f0..8822a593ee4ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.TaskContext
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.trees._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.MemoryAllocator
 
 case class AggregateEvaluation(
     schema: Seq[Attribute],
@@ -290,7 +289,7 @@ case class GeneratedAggregate(
         newAggregationBuffer(EmptyRow),
         aggregationBufferSchema,
         groupKeySchema,
-        SparkEnv.get.unsafeMemoryManager,
+        TaskContext.get.taskMemoryManager(),
         1024 * 16, // initial capacity
         false // disable tracking of performance metrics
       )
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index f464e34e43cd3..821b161c82371 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -48,7 +48,7 @@ public final class BytesToBytesMap {
 
   private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
 
-  private final MemoryManager memoryManager;
+  private final TaskMemoryManager memoryManager;
 
   /**
    * A linked list for tracking all allocated data pages so that we can free all of our memory.
@@ -135,7 +135,7 @@ public final class BytesToBytesMap {
   private long numHashCollisions = 0;
 
   public BytesToBytesMap(
-      MemoryManager memoryManager,
+      TaskMemoryManager memoryManager,
       int initialCapacity,
       double loadFactor,
       boolean enablePerfMetrics) {
@@ -146,12 +146,12 @@ public BytesToBytesMap(
     allocate(initialCapacity);
   }
 
-  public BytesToBytesMap(MemoryManager memoryManager, int initialCapacity) {
+  public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) {
     this(memoryManager, initialCapacity, 0.70, false);
   }
 
   public BytesToBytesMap(
-      MemoryManager memoryManager,
+      TaskMemoryManager memoryManager,
       int initialCapacity,
       boolean enablePerfMetrics) {
     this(memoryManager, initialCapacity, 0.70, enablePerfMetrics);
@@ -438,8 +438,8 @@ public void putNewKey(
    */
   private void allocate(int capacity) {
     capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64);
-    longArray = new LongArray(memoryManager.allocator.allocate(capacity * 8 * 2));
-    bitset = new BitSet(memoryManager.allocator.allocate(capacity / 8).zero());
+    longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2));
+    bitset = new BitSet(memoryManager.allocate(capacity / 8).zero());
 
     this.growthThreshold = (int) (capacity * loadFactor);
     this.mask = capacity - 1;
@@ -453,11 +453,11 @@ private void allocate(int capacity) {
    */
   public void free() {
     if (longArray != null) {
-      memoryManager.allocator.free(longArray.memoryBlock());
+      memoryManager.free(longArray.memoryBlock());
       longArray = null;
     }
     if (bitset != null) {
-      memoryManager.allocator.free(bitset.memoryBlock());
+      memoryManager.free(bitset.memoryBlock());
       bitset = null;
     }
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
@@ -544,8 +544,8 @@ private void growAndRehash() {
     }
 
     // Deallocate the old data structures.
-    memoryManager.allocator.free(oldLongArray.memoryBlock());
-    memoryManager.allocator.free(oldBitSet.memoryBlock());
+    memoryManager.free(oldLongArray.memoryBlock());
+    memoryManager.free(oldBitSet.memoryBlock());
     if (enablePerfMetrics) {
       timeSpentResizingNs += System.nanoTime() - resizeStartTime;
     }
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
new file mode 100644
index 0000000000000..62c29c8cc1e4d
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+/**
+ * Manages memory for an executor. Individual operators / tasks allocate memory through
+ * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager.
+ */
+public class ExecutorMemoryManager {
+
+  /**
+   * Allocator, exposed for enabling untracked allocations of temporary data structures.
+   */
+  public final MemoryAllocator allocator;
+
+  /**
+   * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe.
+   */
+  final boolean inHeap;
+
+  /**
+   * Construct a new ExecutorMemoryManager.
+   *
+   * @param allocator the allocator that will be used
+   */
+  public ExecutorMemoryManager(MemoryAllocator allocator) {
+    this.inHeap = allocator instanceof HeapMemoryAllocator;
+    this.allocator = allocator;
+  }
+
+  /**
+   * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
+   * to be zeroed out (call `zero()` on the result if this is necessary).
+   */
+  MemoryBlock allocate(long size) throws OutOfMemoryError {
+    return allocator.allocate(size);
+  }
+
+  void free(MemoryBlock memory) {
+    allocator.free(memory);
+  }
+
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
similarity index 69%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryManager.java
rename to unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index f3893caf119d0..9224988e6ad69 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -17,13 +17,13 @@
 
 package org.apache.spark.unsafe.memory;
 
-import java.util.BitSet;
+import java.util.*;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * Manages the lifecycle of data pages exchanged between operators.
+ * Manages the memory allocated by an individual task.
  *
 * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit
 * longs. In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode,
 * memory is addressed by the combination of a base Object reference and a 64-bit offset within
@@ -43,9 +43,9 @@
 * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is
 * approximately 35 terabytes of memory.
 */
-public final class MemoryManager {
+public final class TaskMemoryManager {
 
-  private final Logger logger = LoggerFactory.getLogger(MemoryManager.class);
+  private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
 
   /**
    * The number of entries in the page table.
@@ -74,9 +74,12 @@ public final class MemoryManager {
   private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
 
   /**
-   * Allocator, exposed for enabling untracked allocations of temporary data structures.
+   * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean
+   * up leaked memory.
    */
-  public final MemoryAllocator allocator;
+  private final HashSet<MemoryBlock> allocatedNonPageMemory = new HashSet<MemoryBlock>();
+
+  private final ExecutorMemoryManager executorMemoryManager;
 
   /**
    * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
@@ -88,9 +91,9 @@
   /**
    * Construct a new MemoryManager.
    */
-  public MemoryManager(MemoryAllocator allocator) {
-    this.inHeap = allocator instanceof HeapMemoryAllocator;
-    this.allocator = allocator;
+  public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
+    this.inHeap = executorMemoryManager.inHeap;
+    this.executorMemoryManager = executorMemoryManager;
   }
 
   /**
@@ -114,7 +117,7 @@ public MemoryBlock allocatePage(long size) {
       }
       allocatedPages.set(pageNumber);
     }
-    final MemoryBlock page = allocator.allocate(size);
+    final MemoryBlock page = executorMemoryManager.allocate(size);
    page.pageNumber = pageNumber;
    pageTable[pageNumber] = page;
    if (logger.isDebugEnabled()) {
@@ -124,7 +127,7 @@ public MemoryBlock allocatePage(long size) {
   }
 
   /**
-   * Free a block of memory allocated via {@link MemoryManager#allocatePage(long)}.
+   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
    */
   public void freePage(MemoryBlock page) {
     if (logger.isTraceEnabled()) {
@@ -132,7 +135,7 @@ public void freePage(MemoryBlock page) {
     }
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
-    allocator.free(page);
+    executorMemoryManager.free(page);
     synchronized (this) {
       allocatedPages.clear(page.pageNumber);
     }
@@ -142,6 +145,31 @@ public void freePage(MemoryBlock page) {
     }
   }
 
+  /**
+   * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
+   * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended
+   * to be used for allocating operators' internal data structures. For data pages that you want to
+   * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since
+   * that will enable intra-memory pointers (see
+   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's
+   * top-level Javadoc for more details).
+   */
+  public MemoryBlock allocate(long size) throws OutOfMemoryError {
+    final MemoryBlock memory = executorMemoryManager.allocate(size);
+    allocatedNonPageMemory.add(memory);
+    return memory;
+  }
+
+  /**
+   * Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
+   */
+  public void free(MemoryBlock memory) {
+    assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()";
+    executorMemoryManager.free(memory);
+    final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
+    assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
+  }
+
   /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
@@ -157,7 +185,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
 
   /**
    * Get the page associated with an address encoded by
-   * {@link MemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
    */
   public Object getPage(long pagePlusOffsetAddress) {
     if (inHeap) {
@@ -173,7 +201,7 @@ public Object getPage(long pagePlusOffsetAddress) {
 
   /**
    * Get the offset associated with an address encoded by
-   * {@link MemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
    */
   public long getOffsetInPage(long pagePlusOffsetAddress) {
     if (inHeap) {
@@ -184,13 +212,26 @@ public long getOffsetInPage(long pagePlusOffsetAddress) {
   }
 
   /**
-   * Clean up all pages. This shouldn't be called in production code and is only exposed for tests.
+   * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return
+   * value can be used to detect memory leaks.
    */
-  public void cleanUpAllPages() {
+  public long cleanUpAllAllocatedMemory() {
+    long freedBytes = 0;
     for (MemoryBlock page : pageTable) {
       if (page != null) {
+        freedBytes += page.size();
         freePage(page);
       }
     }
+    final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator();
+    while (iter.hasNext()) {
+      final MemoryBlock memory = iter.next();
+      freedBytes += memory.size();
+      // We don't call free() here because that calls Set.remove, which would lead to a
+      // ConcurrentModificationException here.
+      executorMemoryManager.free(memory);
+      iter.remove();
+    }
+    return freedBytes;
   }
 }
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 96fa85302e36b..c59e12182c497 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -28,26 +28,27 @@
 
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.PlatformDependent;
+import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryAllocator;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.MemoryManager;
-import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
 
 public abstract class AbstractBytesToBytesMapSuite {
 
   private final Random rand = new Random(42);
 
-  private MemoryManager memoryManager;
+  private TaskMemoryManager memoryManager;
 
   @Before
   public void setup() {
-    memoryManager = new MemoryManager(getMemoryAllocator());
+    memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
   }
 
   @After
   public void tearDown() {
     if (memoryManager != null) {
-      memoryManager.cleanUpAllPages();
+      memoryManager.cleanUpAllAllocatedMemory();
       memoryManager = null;
     }
   }
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
new file mode 100644
index 0000000000000..932882f1ca248
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TaskMemoryManagerSuite {
+
+  @Test
+  public void leakedNonPageMemoryIsDetected() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    manager.allocate(1024);  // leak memory
+    Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory());
+  }
+
+  @Test
+  public void leakedPageMemoryIsDetected() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    manager.allocatePage(4096);  // leak memory
+    Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
+  }
+
+}
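To make the addressing scheme described in TaskMemoryManager's class Javadoc concrete, here is a small round-trip sketch against the APIs in this patch (the class name, the 4096-byte page, and the 64-byte offset are invented for illustration):

    import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.MemoryBlock;
    import org.apache.spark.unsafe.memory.TaskMemoryManager;

    public class PageAddressRoundTripSketch {
      public static void main(String[] args) {
        TaskMemoryManager manager =
            new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));

        // allocatePage() registers the block in the page table and stamps its pageNumber.
        MemoryBlock page = manager.allocatePage(4096);

        // Pack (pageNumber, offsetInPage) into a single 64-bit long: 13 bits of page
        // number, 51 bits of offset, per the class-level Javadoc.
        long address = manager.encodePageNumberAndOffset(page, 64);

        // Decoding recovers the page (its base object, in on-heap mode) and the offset,
        // so the address can be stored inside other managed data structures.
        Object pageBaseObject = manager.getPage(address);
        long offsetInPage = manager.getOffsetInPage(address);

        // Cleanup doubles as leak detection: 4096 bytes are still outstanding here
        // because we never called freePage(page), matching leakedPageMemoryIsDetected.
        long leakedBytes = manager.cleanUpAllAllocatedMemory();
      }
    }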
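And a sketch of what the BytesToBytesMap constructor change buys: the map's backing arrays and data pages are now drawn from the task-scoped manager, so a map that is never free()d surfaces in that task's leak check instead of lingering for the executor's lifetime (put/lookup traffic elided; only the constructor and free() signatures from the diff are used):

    import org.apache.spark.unsafe.map.BytesToBytesMap;
    import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.TaskMemoryManager;

    public class BytesToBytesMapOwnershipSketch {
      public static void main(String[] args) {
        TaskMemoryManager manager =
            new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));

        // The map allocates its longArray and bitset through manager.allocate(),
        // so they are tracked in the task's allocatedNonPageMemory set.
        BytesToBytesMap map = new BytesToBytesMap(manager, 64);

        // ... inserts and lookups elided ...

        map.free();  // releases the map's allocations back through the manager

        // Because free() was called, nothing should be outstanding; a non-zero
        // value here would be reported as a managed memory leak at end of task.
        long leakedBytes = manager.cleanUpAllAllocatedMemory();
      }
    }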