From 02749c1785f63020c7b998e83353bb9f14e87304 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 30 Jun 2015 23:37:43 +0800 Subject: [PATCH] add unit test --- .../spark/shuffle/ShuffleMemoryManager.scala | 4 +-- .../shuffle/ShuffleMemoryManagerSuite.scala | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 80b5c677f0c66..727063706aa56 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -116,7 +116,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { val toGrant = math.min(maxToGrant, freeMemory) threadMemory(threadId) += toGrant - return toGrant + return this.releaseReservedMemory(toGrant, numBytes) } else { logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") wait() @@ -125,7 +125,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { // Only give it as much memory as is free, which might be none if it reached 1 / numThreads val toGrant = math.min(maxToGrant, freeMemory) threadMemory(threadId) += toGrant - return toGrant + return this.releaseReservedMemory(toGrant, numBytes) } } 0L // Never reached diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 96778c9ebafb1..c31294b4233bc 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -23,6 +23,18 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.CountDownLatch import org.apache.spark.SparkFunSuite +import org.apache.spark.Spillable + +class FakeSpillable extends Spillable { + var myMemoryThreshold: Long = 0L + def addMemory(currentMemory: Long) = { + myMemoryThreshold += currentMemory + } + + override def forceSpill(): Long = { + return myMemoryThreshold + } +} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { /** Launch a thread with the given body block and return it. */ @@ -307,4 +319,17 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { val granted = manager.tryToAcquire(300L) assert(0 === granted, "granted is negative") } + + test("latter spillable grab full memory of previous spillable") { + val manager = new ShuffleMemoryManager(1000L) + val spill1 = new FakeSpillable() + val spill2 = new FakeSpillable() + spill1.addMemory(manager.tryToAcquire(700L)) + spill1.addMemory(manager.tryToAcquire(300L)) + manager.addSpillableToReservedList(spill1) + val granted1 = manager.tryToAcquire(300L) + assert(300L === granted1, "granted memory") + val granted2 = manager.tryToAcquire(800L) + assert(700L === granted2, "granted remained memory") + } }