Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
lianhuiwang committed Jun 30, 2015
1 parent 4ede7ea commit 02749c1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
return toGrant
return this.releaseReservedMemory(toGrant, numBytes)
} else {
logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
Expand All @@ -125,7 +125,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
// Only give it as much memory as is free, which might be none if it reached 1 / numThreads
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
return toGrant
return this.releaseReservedMemory(toGrant, numBytes)
}
}
0L // Never reached
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.CountDownLatch

import org.apache.spark.SparkFunSuite
import org.apache.spark.Spillable

class FakeSpillable extends Spillable {
var myMemoryThreshold: Long = 0L
def addMemory(currentMemory: Long) = {
myMemoryThreshold += currentMemory
}

override def forceSpill(): Long = {
return myMemoryThreshold
}
}

class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
/** Launch a thread with the given body block and return it. */
Expand Down Expand Up @@ -307,4 +319,17 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
val granted = manager.tryToAcquire(300L)
assert(0 === granted, "granted is negative")
}

test("latter spillable grab full memory of previous spillable") {
val manager = new ShuffleMemoryManager(1000L)
val spill1 = new FakeSpillable()
val spill2 = new FakeSpillable()
spill1.addMemory(manager.tryToAcquire(700L))
spill1.addMemory(manager.tryToAcquire(300L))
manager.addSpillableToReservedList(spill1)
val granted1 = manager.tryToAcquire(300L)
assert(300L === granted1, "granted memory")
val granted2 = manager.tryToAcquire(800L)
assert(700L === granted2, "granted remained memory")
}
}

0 comments on commit 02749c1

Please sign in to comment.