diff --git a/core/src/main/java/org/apache/spark/Spillable.java b/core/src/main/java/org/apache/spark/Spillable.java
new file mode 100644
index 0000000000000..a343b93f0f1da
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/Spillable.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+/**
+ * A memory consumer whose in-memory buffer can be forcibly spilled to disk to release memory.
+ */
+public interface Spillable {
+
+  /**
+   * Force the contents of the in-memory buffer to be spilled to disk.
+   * @return the number of bytes spilled
+   */
+  public long forceSpill();
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index 3bcc7178a3d8b..f3466454d8fbf 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle
 
 import scala.collection.mutable
 
-import org.apache.spark.{Logging, SparkException, SparkConf}
+import org.apache.spark._
 
 /**
  * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
@@ -38,8 +38,48 @@ import org.apache.spark.{Logging, SparkException, SparkConf}
 private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
   private val threadMemory = new mutable.HashMap[Long, Long]()  // threadId -> memory bytes
 
+  // threadId -> list of Spillables whose memory is reserved by that thread
+  private val threadReservedList = new mutable.HashMap[Long, mutable.ListBuffer[Spillable]]()
+
   def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
 
+  /**
+   * Release memory reserved by other Spillables of the current thread until the amount
+   * granted reaches requestMemory.
+   */
+  private[this] def releaseReservedMemory(toGrant: Long, requestMemory: Long): Long =
+    synchronized {
+      val threadId = Thread.currentThread().getId
+      if (toGrant >= requestMemory || !threadReservedList.contains(threadId)) {
+        toGrant
+      } else {
+        // Try to release memory held by this thread's Spillables to make space for the new request
+        var addMemory = toGrant
+        while (addMemory < requestMemory && threadReservedList(threadId).nonEmpty) {
+          val toSpill = threadReservedList(threadId).remove(0)
+          val spillMemory = toSpill.forceSpill()
+          logInfo(s"Thread $threadId force spilled $spillMemory bytes to free up memory")
+          addMemory += spillMemory
+        }
+        if (addMemory > requestMemory) {
+          this.release(addMemory - requestMemory)
+          addMemory = requestMemory
+        }
+        addMemory
+      }
+    }
+
+  /**
+   * Add a Spillable to the current thread's reserved list. When the thread does not have
+   * enough memory, the memory held by the Spillables in this list can be released.
+   */
+  private[spark] def addSpillableToReservedList(spill: Spillable): Unit = synchronized {
+    val threadId = Thread.currentThread().getId
+    if (!threadReservedList.contains(threadId)) {
+      threadReservedList(threadId) = new mutable.ListBuffer[Spillable]()
+    }
+    threadReservedList(threadId) += spill
+  }
+
   /**
    * Try to acquire up to numBytes memory for the current thread, and return the number of bytes
    * obtained, or 0 if none can be allocated. This call may block until there is enough free memory
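The two methods above define the force-spill protocol: a disk-spilling collection registers itself with addSpillableToReservedList, and a later tryToAcquire call from the same thread can reclaim that memory by force spilling the registered collections. The following sketch illustrates the protocol in isolation; SimpleMemoryManager and FakeCollection are simplified, hypothetical stand-ins for ShuffleMemoryManager and the spillable collections, not code from this patch.

```scala
import scala.collection.mutable

// Hypothetical stand-ins, for illustration only: the trait and manager mirror the shape of
// org.apache.spark.Spillable and ShuffleMemoryManager but are heavily simplified.
trait SimpleSpillable {
  def forceSpill(): Long
}

class SimpleMemoryManager(maxMemory: Long) {
  private var used = 0L
  private val reserved = mutable.ListBuffer[SimpleSpillable]()

  def addSpillableToReservedList(s: SimpleSpillable): Unit = synchronized { reserved += s }

  def release(numBytes: Long): Unit = synchronized { used -= numBytes }

  def tryToAcquire(numBytes: Long): Long = synchronized {
    var granted = math.min(numBytes, maxMemory - used)
    used += granted
    // Not enough free memory: force spill previously reserved collections until the
    // request is satisfied, crediting the freed bytes to this request.
    while (granted < numBytes && reserved.nonEmpty) {
      granted += reserved.remove(0).forceSpill()
    }
    if (granted > numBytes) {  // hand back any excess freed by the spill
      release(granted - numBytes)
      granted = numBytes
    }
    granted
  }
}

// A collection that hands back everything it holds when asked to spill.
class FakeCollection(manager: SimpleMemoryManager) extends SimpleSpillable {
  private var held = 0L
  def insert(bytes: Long): Unit = { held += manager.tryToAcquire(bytes) }
  override def forceSpill(): Long = { val freed = held; held = 0L; freed }  // pretend we wrote to disk
}

object ForceSpillDemo extends App {
  val manager = new SimpleMemoryManager(1000L)
  val first = new FakeCollection(manager)
  manager.addSpillableToReservedList(first)
  first.insert(1000L)                  // first collection takes the whole pool
  println(manager.tryToAcquire(400L))  // 400: "first" was force spilled to make room
}
```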
@@ -77,7 +117,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
         if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
           val toGrant = math.min(maxToGrant, freeMemory)
           threadMemory(threadId) += toGrant
-          return toGrant
+          return this.releaseReservedMemory(toGrant, numBytes)
         } else {
           logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
           wait()
@@ -86,7 +126,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
         // Only give it as much memory as is free, which might be none if it reached 1 / numThreads
         val toGrant = math.min(maxToGrant, freeMemory)
         threadMemory(threadId) += toGrant
-        return toGrant
+        return this.releaseReservedMemory(toGrant, numBytes)
       }
     }
     0L  // Never reached
@@ -108,6 +148,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
   def releaseMemoryForThisThread(): Unit = synchronized {
     val threadId = Thread.currentThread().getId
     threadMemory.remove(threadId)
+    threadReservedList.remove(threadId)
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/CollectionSpillable.scala
similarity index 83%
rename from core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
rename to core/src/main/scala/org/apache/spark/util/collection/CollectionSpillable.scala
index 747ecf075a397..00cdf35cbcfde 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/CollectionSpillable.scala
@@ -17,14 +17,13 @@ package org.apache.spark.util.collection
 
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import org.apache.spark.{Logging, SparkEnv, Spillable}
 
 /**
  * Spills contents of an in-memory collection to disk when the memory threshold
  * has been exceeded.
  */
-private[spark] trait Spillable[C] extends Logging {
+private[spark] trait CollectionSpillable[C] extends Logging with Spillable {
   /**
   * Spills the current in-memory collection to disk, and releases the memory.
  *
@@ -40,16 +39,10 @@ private[spark] trait Spillable[C] extends Logging {
   protected def addElementsRead(): Unit = { _elementsRead += 1 }
 
   // Memory manager that can be used to acquire/release memory
-  private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
-
-  // Initial threshold for the size of a collection before we start tracking its memory usage
-  // Exposed for testing
-  private[this] val initialMemoryThreshold: Long =
-    SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
+  protected val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
 
   // Threshold for this collection's size in bytes before we start tracking its memory usage
-  // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
-  private[this] var myMemoryThreshold = initialMemoryThreshold
+  private[this] var myMemoryThreshold = 0L
 
   // Number of elements read from input since last spill
   private[this] var _elementsRead = 0L
@@ -102,8 +95,8 @@ private[spark] trait Spillable[C] extends Logging {
   */
  private def releaseMemoryForThisThread(): Unit = {
    // The amount we requested does not include the initial memory tracking threshold
-    shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold)
-    myMemoryThreshold = initialMemoryThreshold
+    shuffleMemoryManager.release(myMemoryThreshold)
+    myMemoryThreshold = 0L
  }
 
  /**
@@ -117,4 +110,18 @@ private[spark] trait Spillable[C] extends Logging {
      .format(threadId, org.apache.spark.util.Utils.bytesToString(size),
        _spillCount, if (_spillCount > 1) "s" else ""))
  }
+
+  /**
+   * Log the forced spill and return the collection's in-memory size.
+   */
+  protected def logForceSpill(currentMemory: Long): Long = {
+    _spillCount += 1
+    logSpillage(currentMemory)
+
+    _elementsRead = 0
+    _memoryBytesSpilled += currentMemory
+    val freeMemory = myMemoryThreshold
+    myMemoryThreshold = 0L
+    freeMemory
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 1e4531ef395ae..eccd1353abe15 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -69,7 +69,7 @@ class ExternalAppendOnlyMap[K, V, C](
   extends Iterable[(K, C)]
   with Serializable
   with Logging
-  with Spillable[SizeTracker] {
+  with CollectionSpillable[SizeTracker] {
 
   private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
   private val spilledMaps = new ArrayBuffer[DiskMapIterator]
@@ -100,6 +100,8 @@ class ExternalAppendOnlyMap[K, V, C](
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
 
+  private var memoryOrDiskIter: Option[MemoryOrDiskIterator] = None
+
   /**
    * Insert the given key and value into the map.
    */
@@ -151,6 +153,14 @@ class ExternalAppendOnlyMap[K, V, C](
   * Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
   */
  override protected[this] def spill(collection: SizeTracker): Unit = {
+    val it = currentMap.destructiveSortedIterator(keyComparator)
+    spilledMaps.append(spillMemoryToDisk(it))
+  }
+
+  /**
+   * Spill the contents of the in-memory map to a temporary file on disk.
+   */
+  private[this] def spillMemoryToDisk(inMemory: Iterator[(K, C)]): DiskMapIterator = {
    val (blockId, file) = diskBlockManager.createTempLocalBlock()
    curWriteMetrics = new ShuffleWriteMetrics()
    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -171,9 +181,8 @@ class ExternalAppendOnlyMap[K, V, C](
 
    var success = false
    try {
-      val it = currentMap.destructiveSortedIterator(keyComparator)
-      while (it.hasNext) {
-        val kv = it.next()
+      while (inMemory.hasNext) {
+        val kv = inMemory.next()
        writer.write(kv._1, kv._2)
        objectsWritten += 1
 
@@ -203,8 +212,7 @@ class ExternalAppendOnlyMap[K, V, C](
        }
      }
    }
-
-    spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
+    new DiskMapIterator(file, blockId, batchSizes)
  }
 
  def diskBytesSpilled: Long = _diskBytesSpilled
@@ -214,13 +222,53 @@ class ExternalAppendOnlyMap[K, V, C](
   * If no spill has occurred, simply return the in-memory map's iterator.
   */
  override def iterator: Iterator[(K, C)] = {
+    shuffleMemoryManager.addSpillableToReservedList(this)
    if (spilledMaps.isEmpty) {
-      currentMap.iterator
+      memoryOrDiskIter = Some(MemoryOrDiskIterator(currentMap.iterator))
+      memoryOrDiskIter.get
    } else {
      new ExternalIterator()
    }
  }
 
+  /**
+   * Spill the contents of the in-memory map to disk and release its memory.
+   */
+  override def forceSpill(): Long = {
+    var freeMemory = 0L
+    if (memoryOrDiskIter.isDefined) {
+      freeMemory = logForceSpill(currentMap.estimateSize())
+      memoryOrDiskIter.get.spill()
+    }
+    freeMemory
+  }
+
+  /**
+   * An iterator that reads elements from the in-memory iterator, or from the disk iterator
+   * after the in-memory contents have been spilled to disk.
+   */
+  case class MemoryOrDiskIterator(memIter: Iterator[(K, C)]) extends Iterator[(K, C)] {
+
+    var currentIter = memIter
+
+    override def hasNext: Boolean = currentIter.hasNext
+
+    override def next(): (K, C) = currentIter.next()
+
+    private[spark] def spill(): Unit = {
+      if (hasNext) {
+        currentIter = spillMemoryToDisk(currentIter)
+      } else {
+        // The in-memory iterator is already drained; release it by swapping in an empty iterator
+        currentIter = new Iterator[(K, C)] {
+          override def hasNext: Boolean = false
+          override def next(): (K, C) = null
+        }
+        logInfo("In-memory iterator already drained; nothing to spill")
+      }
+    }
+  }
+
  /**
   * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
   */
@@ -232,7 +280,9 @@ class ExternalAppendOnlyMap[K, V, C](
 
    // Input streams are derived both from the in-memory map and spilled maps on disk
    // The in-memory map is sorted in place, while the spilled maps are already in sorted order
-    private val sortedMap = currentMap.destructiveSortedIterator(keyComparator)
+    memoryOrDiskIter = Some(MemoryOrDiskIterator(
+      currentMap.destructiveSortedIterator(keyComparator)))
+    private val sortedMap = memoryOrDiskIter.get
    private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
 
    inputStreams.foreach { it =>
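MemoryOrDiskIterator is what makes a force spill safe while the map is being iterated: reading starts from the in-memory iterator, and if the memory manager later spills this collection, the remaining elements are written to disk and iteration continues transparently from the spilled file. Below is a minimal, self-contained sketch of that hand-off; the spillToDisk parameter is a toy stand-in for spillMemoryToDisk, and none of this code is part of the patch.

```scala
// Hypothetical sketch of the iterator-switching pattern used by MemoryOrDiskIterator.
class SwitchableIterator[T](memIter: Iterator[T], spillToDisk: Iterator[T] => Iterator[T])
  extends Iterator[T] {

  private var current: Iterator[T] = memIter

  override def hasNext: Boolean = current.hasNext
  override def next(): T = current.next()

  /** Spill whatever has not been consumed yet and continue reading it back from "disk". */
  def spill(): Unit = {
    current = if (current.hasNext) spillToDisk(current) else Iterator.empty
  }
}

object SwitchableIteratorDemo extends App {
  // Stand-in for a real spill: here we just materialize the remaining elements.
  val it = new SwitchableIterator[Int](Iterator(1, 2, 3, 4), rest => rest.toVector.iterator)
  println(it.next())  // 1, served from memory
  it.spill()          // remaining elements "move" to disk
  println(it.toList)  // List(2, 3, 4), served from the disk-backed iterator
}
```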
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 757dec66c203b..70dbb501a0e9d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -93,7 +93,7 @@ private[spark] class ExternalSorter[K, V, C](
    ordering: Option[Ordering[K]] = None,
    serializer: Option[Serializer] = None)
  extends Logging
-  with Spillable[WritablePartitionedPairCollection[K, C]]
+  with CollectionSpillable[WritablePartitionedPairCollection[K, C]]
  with SortShuffleFileWriter[K, V] {
 
  private val conf = SparkEnv.get.conf
@@ -148,6 +148,9 @@ private[spark] class ExternalSorter[K, V, C](
  private var map = new PartitionedAppendOnlyMap[K, C]
  private var buffer = newBuffer()
 
+  private var memoryOrDiskIter: Option[MemoryOrDiskIterator] = None
+  private var isShuffleSort: Boolean = true
+
  // Total spilling statistics
  private var _diskBytesSpilled = 0L
  def diskBytesSpilled: Long = _diskBytesSpilled
@@ -177,7 +180,7 @@ private[spark] class ExternalSorter[K, V, C](
  // Information about a spilled file. Includes sizes in bytes of "batches" written by the
  // serializer as we periodically reset its stream, as well as number of elements in each
  // partition, used to efficiently keep track of partitions when merging.
-  private[this] case class SpilledFile(
+  private[spark] case class SpilledFile(
    file: File,
    blockId: BlockId,
    serializerBatchSizes: Array[Long],
@@ -242,76 +245,9 @@ private[spark] class ExternalSorter[K, V, C](
   * @param collection whichever collection we're using (map or buffer)
   */
  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
-    // Because these files may be read during shuffle, their compression must be controlled by
-    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
-    // createTempShuffleBlock here; see SPARK-3426 for more context.
-    val (blockId, file) = diskBlockManager.createTempShuffleBlock()
-
-    // These variables are reset after each flush
-    var objectsWritten: Long = 0
-    var spillMetrics: ShuffleWriteMetrics = null
-    var writer: BlockObjectWriter = null
-    def openWriter(): Unit = {
-      assert (writer == null && spillMetrics == null)
-      spillMetrics = new ShuffleWriteMetrics
-      writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
-    }
-    openWriter()
-
-    // List of batch sizes (bytes) in the order they are written to disk
-    val batchSizes = new ArrayBuffer[Long]
-
-    // How many elements we have in each partition
-    val elementsPerPartition = new Array[Long](numPartitions)
-
-    // Flush the disk writer's contents to disk, and update relevant variables.
-    // The writer is closed at the end of this process, and cannot be reused.
-    def flush(): Unit = {
-      val w = writer
-      writer = null
-      w.commitAndClose()
-      _diskBytesSpilled += spillMetrics.shuffleBytesWritten
-      batchSizes.append(spillMetrics.shuffleBytesWritten)
-      spillMetrics = null
-      objectsWritten = 0
-    }
-
-    var success = false
-    try {
-      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
-      while (it.hasNext) {
-        val partitionId = it.nextPartition()
-        it.writeNext(writer)
-        elementsPerPartition(partitionId) += 1
-        objectsWritten += 1
-
-        if (objectsWritten == serializerBatchSize) {
-          flush()
-          openWriter()
-        }
-      }
-      if (objectsWritten > 0) {
-        flush()
-      } else if (writer != null) {
-        val w = writer
-        writer = null
-        w.revertPartialWritesAndClose()
-      }
-      success = true
-    } finally {
-      if (!success) {
-        // This code path only happens if an exception was thrown above before we set success;
-        // close our stuff and let the exception be thrown further
-        if (writer != null) {
-          writer.revertPartialWritesAndClose()
-        }
-        if (file.exists()) {
-          file.delete()
-        }
-      }
-    }
-
-    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
+    val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+    val spillFile = spillMemoryToDisk(it)
+    spills.append(spillFile)
  }
 
  /**
@@ -603,8 +539,159 @@ private[spark] class ExternalSorter[K, V, C](
  }
 
  /**
-   * Return an iterator over all the data written to this object, grouped by partition and
-   * aggregated by the requested aggregator. For each partition we then have an iterator over its
+   * Spill the contents of an in-memory iterator to a temporary file on disk.
+   */
+  private def spillMemoryToDisk(inMemory: WritablePartitionedIterator): SpilledFile = {
+    // Because these files may be read during shuffle, their compression must be controlled by
+    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+    // createTempShuffleBlock here; see SPARK-3426 for more context.
+    val (blockId, file) = diskBlockManager.createTempShuffleBlock()
+
+    // These variables are reset after each flush
+    var objectsWritten: Long = 0
+    var spillMetrics: ShuffleWriteMetrics = null
+    var writer: BlockObjectWriter = null
+    def openWriter(): Unit = {
+      assert (writer == null && spillMetrics == null)
+      spillMetrics = new ShuffleWriteMetrics
+      writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
+    }
+    openWriter()
+
+    // List of batch sizes (bytes) in the order they are written to disk
+    val batchSizes = new ArrayBuffer[Long]
+
+    // How many elements we have in each partition
+    val elementsPerPartition = new Array[Long](numPartitions)
+
+    // Flush the disk writer's contents to disk, and update relevant variables.
+    // The writer is closed at the end of this process, and cannot be reused.
+    def flush(): Unit = {
+      val w = writer
+      writer = null
+      w.commitAndClose()
+      _diskBytesSpilled += spillMetrics.shuffleBytesWritten
+      batchSizes.append(spillMetrics.shuffleBytesWritten)
+      spillMetrics = null
+      objectsWritten = 0
+    }
+
+    var success = false
+    try {
+      while (inMemory.hasNext) {
+        val partitionId = inMemory.nextPartition()
+        inMemory.writeNext(writer)
+        elementsPerPartition(partitionId) += 1
+        objectsWritten += 1
+
+        if (objectsWritten == serializerBatchSize) {
+          flush()
+          openWriter()
+        }
+      }
+      if (objectsWritten > 0) {
+        flush()
+      } else if (writer != null) {
+        val w = writer
+        writer = null
+        w.revertPartialWritesAndClose()
+      }
+      success = true
+    } finally {
+      if (!success) {
+        // This code path only happens if an exception was thrown above before we set success;
+        // close our stuff and let the exception be thrown further
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+        }
+        if (file.exists()) {
+          file.delete()
+        }
+      }
+    }
+    SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
+  }
+
+  /**
+   * Spill the given in-memory iterator to a temporary file on disk.
+   */
+  private[this] def spillMemoryToDisk(iterator: Iterator[((Int, K), C)]): SpilledFile = {
+    val it = new WritablePartitionedIterator {
+      private[this] var cur = if (iterator.hasNext) iterator.next() else null
+
+      def writeNext(writer: BlockObjectWriter): Unit = {
+        writer.write(cur._1._2, cur._2)
+        cur = if (iterator.hasNext) iterator.next() else null
+      }
+
+      def hasNext(): Boolean = cur != null
+
+      def nextPartition(): Int = cur._1._1
+    }
+
+    spillMemoryToDisk(it)
+  }
+
+  /**
+   * An iterator that reads elements from the in-memory iterator, or from the disk iterator
+   * after the in-memory contents have been spilled to disk.
+   */
+  case class MemoryOrDiskIterator(memIter: Iterator[((Int, K), C)])
+    extends Iterator[((Int, K), C)] {
+
+    var currentIter = memIter
+    var spillFile: Option[SpilledFile] = None
+
+    override def hasNext: Boolean = currentIter.hasNext
+
+    override def next(): ((Int, K), C) = currentIter.next()
+
+    private[spark] def spill(): Unit = {
+      if (hasNext) {
+        spillFile = Some(spillMemoryToDisk(currentIter))
+        val spillReader = new SpillReader(spillFile.get)
+
+        currentIter = (0 until numPartitions).iterator.flatMap { p =>
+          val iterator = spillReader.readNextPartition()
+          iterator.map(cur => ((p, cur._1), cur._2))
+        }
+      } else {
+        // The in-memory iterator is already drained; release it by swapping in an empty iterator
+        currentIter = new Iterator[((Int, K), C)] {
+          override def hasNext: Boolean = false
+          override def next(): ((Int, K), C) = null
+        }
+        logInfo("In-memory iterator already drained; nothing to spill")
+      }
+    }
+
+    private[spark] def cleanup(): Unit = {
+      spillFile.foreach(_.file.delete())
+    }
+  }
+
+  /**
+   * Spill the contents of the in-memory collection (map or buffer) to disk and release its memory.
+   */
+  override def forceSpill(): Long = {
+    var freeMemory = 0L
+    if (memoryOrDiskIter.isDefined) {
+      val shouldCombine = aggregator.isDefined
+      if (shouldCombine) {
+        freeMemory = logForceSpill(map.estimateSize())
+      } else {
+        freeMemory = logForceSpill(buffer.estimateSize())
+      }
+      memoryOrDiskIter.get.spill()
+    }
+    freeMemory
+  }
+
+  /**
+   * Return an iterator over all the data written to this object, grouped by partition and
+   * aggregated by the requested aggregator. For each partition we then have an iterator over its
   * contents, and these are expected to be accessed in order (you can't "skip ahead" to one
   * partition without reading the previous one). Guaranteed to return a key-value pair for each
   * partition, in order of partition ID.
@@ -616,26 +703,42 @@ private[spark] class ExternalSorter[K, V, C](
  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
    val usingMap = aggregator.isDefined
    val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
+    def changeIterToMemoryOrDiskIter(inMemory: Iterator[((Int, K), C)]) = {
+      if (isShuffleSort) {
+        inMemory
+      } else {
+        memoryOrDiskIter = Some(MemoryOrDiskIterator(inMemory))
+        memoryOrDiskIter.get
+      }
+    }
+
    if (spills.isEmpty) {
      // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
      // we don't even need to sort by anything other than partition ID
      if (!ordering.isDefined) {
        // The user hasn't requested sorted keys, so only sort by partition ID, not key
-        groupByPartition(collection.partitionedDestructiveSortedIterator(None))
+        groupByPartition(changeIterToMemoryOrDiskIter(
+          collection.partitionedDestructiveSortedIterator(None)))
      } else {
        // We do need to sort by both partition ID and key
-        groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
+        groupByPartition(changeIterToMemoryOrDiskIter(
+          collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
      }
    } else {
      // Merge spilled and in-memory data
-      merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
+      merge(spills, changeIterToMemoryOrDiskIter(
+        collection.partitionedDestructiveSortedIterator(comparator)))
    }
  }
 
  /**
   * Return an iterator over all the data written to this object, aggregated by our aggregator.
   */
-  def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
+  def iterator: Iterator[Product2[K, C]] = {
+    isShuffleSort = false
+    shuffleMemoryManager.addSpillableToReservedList(this)
+    partitionedIterator.flatMap(pair => pair._2)
+  }
 
  /**
   * Write all the data added into this ExternalSorter into a file in the disk store. This is
@@ -693,6 +796,7 @@ private[spark] class ExternalSorter[K, V, C](
  def stop(): Unit = {
    spills.foreach(s => s.file.delete())
    spills.clear()
+    memoryOrDiskIter.foreach(_.cleanup())
  }
 
  /**
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
index 96778c9ebafb1..17071d1eed2b4 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
@@ -23,6 +23,20 @@ import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.CountDownLatch
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.Spillable
+
+private[this] class FakeSpillable extends Spillable {
+
+  var myMemoryThreshold: Long = 0L
+
+  private[spark] def addMemory(currentMemory: Long): Unit = {
+    myMemoryThreshold += currentMemory
+  }
+
+  override def forceSpill(): Long = {
+    myMemoryThreshold
+  }
+}
 
 class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
  /** Launch a thread with the given body block and return it. */
@@ -307,4 +321,19 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
    val granted = manager.tryToAcquire(300L)
    assert(0 === granted, "granted is negative")
  }
+
+  test("a later request can reclaim memory reserved by a previous spillable") {
+    val manager = new ShuffleMemoryManager(1000L)
+
+    val spill1 = new FakeSpillable()
+
+    spill1.addMemory(manager.tryToAcquire(700L))
+    spill1.addMemory(manager.tryToAcquire(300L))
+    manager.addSpillableToReservedList(spill1)
+
+    val granted1 = manager.tryToAcquire(300L)
+    assert(300L === granted1, "should grant the requested 300 bytes after force spilling")
+    val granted2 = manager.tryToAcquire(800L)
+    assert(700L === granted2, "should grant only the remaining 700 bytes")
+  }
 }
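The new test exercises the reclaim path end to end: spill1 reserves the entire 1000-byte pool, the next 300-byte request force spills it and returns the 700-byte surplus to the pool, and the final 800-byte request can only be granted the 700 bytes still free. A natural companion check, sketched below under the assumption that it would live in the same suite (it is not part of this patch), is that releaseMemoryForThisThread also clears the thread's reserved list added in this change:

```scala
// Hypothetical follow-up test, not included in the patch: after the owning thread releases
// its memory, the reserved-spillable list should be gone as well, so a fresh request can
// take the whole pool without any force spill.
test("reserved list is cleared when the thread releases its memory") {
  val manager = new ShuffleMemoryManager(1000L)

  val spill1 = new FakeSpillable()
  spill1.addMemory(manager.tryToAcquire(1000L))
  manager.addSpillableToReservedList(spill1)

  manager.releaseMemoryForThisThread()       // drops both threadMemory and threadReservedList

  val granted = manager.tryToAcquire(1000L)  // nothing needs to be force spilled now
  assert(1000L === granted, "should grant the whole pool to the new request")
}
```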