Skip to content

Commit

Permalink
Add additional defenses against use of freed MemoryBlocks
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jan 8, 2018
1 parent 4f7e758 commit a7f8c07
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 13 deletions.
Expand Up @@ -31,8 +31,7 @@
public class HeapMemoryAllocator implements MemoryAllocator {

@GuardedBy("this")
private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize =
new HashMap<>();
private final Map<Long, LinkedList<WeakReference<long[]>>> bufferPoolsBySize = new HashMap<>();

private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;

Expand All @@ -49,13 +48,14 @@ private boolean shouldPool(long size) {
public MemoryBlock allocate(long size) throws OutOfMemoryError {
if (shouldPool(size)) {
synchronized (this) {
final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
final LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
if (pool != null) {
while (!pool.isEmpty()) {
final WeakReference<MemoryBlock> blockReference = pool.pop();
final MemoryBlock memory = blockReference.get();
if (memory != null) {
assert (memory.size() == size);
final WeakReference<long[]> arrayReference = pool.pop();
final long[] array = arrayReference.get();
if (array != null) {
assert (array.length * 8L >= size);
MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
}
Expand All @@ -76,18 +76,35 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {

@Override
public void free(MemoryBlock memory) {
assert (memory.obj != null) :
"baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?";
assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
"page has already been freed";
assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
|| (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
"TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()";

final long size = memory.size();
if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
}

// Mark the page as freed (so we can detect double-frees).
memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;

// As an additional layer of defense against use-after-free bugs, we mutate the
// MemoryBlock to null out its reference to the long[] array.
long[] array = (long[]) memory.obj;
memory.setObjAndOffset(null, 0);

if (shouldPool(size)) {
synchronized (this) {
LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
if (pool == null) {
pool = new LinkedList<>();
bufferPoolsBySize.put(size, pool);
}
pool.add(new WeakReference<>(memory));
pool.add(new WeakReference<>(array));
}
} else {
// Do nothing
Expand Down
Expand Up @@ -26,14 +26,33 @@
*/
public class MemoryBlock extends MemoryLocation {

/** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */
public static final int NO_PAGE_NUMBER = -1;

/**
* Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager.
* We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator
* can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM
* before being passed to MemoryAllocator.free() (it is an error to allocate a page in
* TaskMemoryManager and then directly free it in a MemoryAllocator without going through
* the TMM freePage() call).
*/
public static final int FREED_IN_TMM_PAGE_NUMBER = -2;

/**
* Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows
* us to detect double-frees.
*/
public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3;

private final long length;

/**
* Optional page number; used when this MemoryBlock represents a page allocated by a
* TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager,
* which lives in a different package.
*/
public int pageNumber = -1;
public int pageNumber = NO_PAGE_NUMBER;

public MemoryBlock(@Nullable Object obj, long offset, long length) {
super(obj, offset);
Expand Down
Expand Up @@ -38,9 +38,20 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
public void free(MemoryBlock memory) {
assert (memory.obj == null) :
"baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?";
assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
"page has already been freed";
assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
|| (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
"TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()";

if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
}
Platform.freeMemory(memory.offset);
// As an additional layer of defense against use-after-free bugs, we mutate the
// MemoryBlock to reset its pointer.
memory.offset = 0;
// Mark the page as freed (so we can detect double-frees).
memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
}
}
Expand Up @@ -62,6 +62,52 @@ public void overlappingCopyMemory() {
}
}

@Test
public void onHeapMemoryAllocatorPoolingReUsesLongArrays() {
MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
Object baseObject1 = block1.getBaseObject();
MemoryAllocator.HEAP.free(block1);
MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
Object baseObject2 = block2.getBaseObject();
Assert.assertSame(baseObject1, baseObject2);
MemoryAllocator.HEAP.free(block2);
}

@Test
public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() {
MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
Assert.assertNotNull(block.getBaseObject());
MemoryAllocator.HEAP.free(block);
Assert.assertNull(block.getBaseObject());
Assert.assertEquals(0, block.getBaseOffset());
Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber);
}

@Test
public void freeingOffHeapMemoryBlockResetsOffset() {
MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
Assert.assertNull(block.getBaseObject());
Assert.assertNotEquals(0, block.getBaseOffset());
MemoryAllocator.UNSAFE.free(block);
Assert.assertNull(block.getBaseObject());
Assert.assertEquals(0, block.getBaseOffset());
Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber);
}

@Test(expected = AssertionError.class)
public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
MemoryAllocator.HEAP.free(block);
MemoryAllocator.HEAP.free(block);
}

@Test(expected = AssertionError.class)
public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
MemoryAllocator.UNSAFE.free(block);
MemoryAllocator.UNSAFE.free(block);
}

@Test
public void memoryDebugFillEnabledInTest() {
Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED);
Expand All @@ -71,9 +117,11 @@ public void memoryDebugFillEnabledInTest() {
MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);

MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
Object onheap1BaseObject = onheap1.getBaseObject();
long onheap1BaseOffset = onheap1.getBaseOffset();
MemoryAllocator.HEAP.free(onheap1);
Assert.assertEquals(
Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()),
Platform.getByte(onheap1BaseObject, onheap1BaseOffset),
MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
Assert.assertEquals(
Expand Down
13 changes: 11 additions & 2 deletions core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
Expand Up @@ -321,8 +321,12 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
* Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
*/
public void freePage(MemoryBlock page, MemoryConsumer consumer) {
assert (page.pageNumber != -1) :
assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
"Called freePage() on memory that wasn't allocated with allocatePage()";
assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
"Called freePage() on a memory block that has already been freed";
assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
"Called freePage() on a memory block that has already been freed";
assert(allocatedPages.get(page.pageNumber));
pageTable[page.pageNumber] = null;
synchronized (this) {
Expand All @@ -332,6 +336,10 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) {
logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
}
long pageSize = page.size();
// Clear the page number before passing the block to the MemoryAllocator's free().
// Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed
// page has been inappropriately directly freed without calling TMM.freePage().
page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
memoryManager.tungstenMemoryAllocator().free(page);
releaseExecutionMemory(pageSize, consumer);
}
Expand All @@ -358,7 +366,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {

@VisibleForTesting
public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page";
return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
}

Expand Down Expand Up @@ -424,6 +432,7 @@ public long cleanUpAllAllocatedMemory() {
for (MemoryBlock page : pageTable) {
if (page != null) {
logger.debug("unreleased page: " + page + " in task " + taskAttemptId);
page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
memoryManager.tungstenMemoryAllocator().free(page);
}
}
Expand Down
Expand Up @@ -21,6 +21,7 @@
import org.junit.Test;

import org.apache.spark.SparkConf;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;

public class TaskMemoryManagerSuite {
Expand Down Expand Up @@ -68,6 +69,34 @@ public void encodePageNumberAndOffsetOnHeap() {
Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
}

@Test
public void freeingPageSetsPageNumberToSpecialConstant() {
final TaskMemoryManager manager = new TaskMemoryManager(
new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
final MemoryBlock dataPage = manager.allocatePage(256, c);
c.freePage(dataPage);
Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber);
}

@Test(expected = AssertionError.class)
public void freeingPageDirectlyInAllocatorTriggersAssertionError() {
final TaskMemoryManager manager = new TaskMemoryManager(
new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
final MemoryBlock dataPage = manager.allocatePage(256, c);
MemoryAllocator.HEAP.free(dataPage);
}

@Test(expected = AssertionError.class)
public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() {
final TaskMemoryManager manager = new TaskMemoryManager(
new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256);
manager.freePage(dataPage, c);
}

@Test
public void cooperativeSpilling() {
final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
Expand Down

0 comments on commit a7f8c07

Please sign in to comment.