From 52a99819506aa32e3d146b639cf597948f14c8cd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 18:31:45 -0700 Subject: [PATCH] Fix some bugs in the address packing code. The problem is that TaskMemoryManager expects offsets to include the page base address whereas PackedRecordPointer did not. --- .../shuffle/unsafe/PackedRecordPointer.java | 10 ++-- .../unsafe/PackedRecordPointerSuite.java | 16 ++++-- .../unsafe/memory/TaskMemoryManager.java | 52 +++++++++++++++---- .../unsafe/memory/TaskMemoryManagerSuite.java | 7 ++- 4 files changed, 63 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java index 6d61b1b9e34da..4ee6a82c0423e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -68,9 +68,8 @@ public static long packPointer(long recordPointer, int partitionId) { assert (partitionId <= MAXIMUM_PARTITION_ID); // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. - final int pageNumber = (int) ((recordPointer & MASK_LONG_UPPER_13_BITS) >>> 51); - final long compressedAddress = - (((long) pageNumber) << 27) | (recordPointer & MASK_LONG_LOWER_27_BITS); + final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24; + final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS); return (((long) partitionId) << 40) | compressedAddress; } @@ -85,9 +84,8 @@ public int getPartitionId() { } public long getRecordPointer() { - final long compressedAddress = packedRecordPointer & MASK_LONG_LOWER_40_BITS; - final long pageNumber = (compressedAddress << 24) & MASK_LONG_UPPER_13_BITS; - final long offsetInPage = compressedAddress & MASK_LONG_LOWER_27_BITS; + final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS; + final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS; return pageNumber | offsetInPage; } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java index 4fda87ab57c49..db9e82759090a 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java @@ -34,11 +34,15 @@ public void heap() { new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); final MemoryBlock page0 = memoryManager.allocatePage(100); final MemoryBlock page1 = memoryManager.allocatePage(100); - final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); assertEquals(360, packedPointer.getPartitionId()); - assertEquals(addressInPage1, packedPointer.getRecordPointer()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); memoryManager.cleanUpAllAllocatedMemory(); } @@ -48,11 +52,15 @@ public void offHeap() { new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); final MemoryBlock page0 = memoryManager.allocatePage(100); final MemoryBlock page1 = memoryManager.allocatePage(100); - final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); assertEquals(360, packedPointer.getPartitionId()); - assertEquals(addressInPage1, packedPointer.getRecordPointer()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); memoryManager.cleanUpAllAllocatedMemory(); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index cfd54035bee99..2906ac8abad1a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -48,10 +48,18 @@ public final class TaskMemoryManager { private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); - /** - * The number of entries in the page table. - */ - private static final int PAGE_TABLE_SIZE = 1 << 13; + /** The number of bits used to address the page table. */ + private static final int PAGE_NUMBER_BITS = 13; + + /** The number of bits used to encode offsets in data pages. */ + @VisibleForTesting + static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 + + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + /** Maximum supported data page size */ + private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS); /** Bit mask for the lower 51 bits of a long. */ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; @@ -102,8 +110,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { * intended for allocating large blocks of memory that will be shared between operators. */ public MemoryBlock allocatePage(long size) { - if (size >= (1L << 51)) { - throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + if (size > MAXIMUM_PAGE_SIZE) { + throw new IllegalArgumentException( + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes"); } final int pageNumber; @@ -168,15 +177,36 @@ public void free(MemoryBlock memory) { /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an + * UNSAFE call (e.g. page.baseOffset() + something). + * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (!inHeap) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. + offsetInPage -= page.getBaseOffset(); + } return encodePageNumberAndOffset(page.pageNumber, offsetInPage); } @VisibleForTesting public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; - return (((long) pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); } /** @@ -185,7 +215,7 @@ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) */ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { - final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final Object page = pageTable[pageNumber].getBaseObject(); assert (page != null); @@ -200,11 +230,13 @@ public Object getPage(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public long getOffsetInPage(long pagePlusOffsetAddress) { - final long offsetInPage = (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); if (inHeap) { return offsetInPage; } else { - final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); return pageTable[pageNumber].getBaseOffset() + offsetInPage; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java index 8ace8625abb64..06fb081183659 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -43,9 +43,12 @@ public void encodePageNumberAndOffsetOffHeap() { final TaskMemoryManager manager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); final MemoryBlock dataPage = manager.allocatePage(256); - final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + // In off-heap mode, an offset is an absolute address that may require more than 51 bits to + // encode. This test exercises that corner-case: + final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); Assert.assertEquals(null, manager.getPage(encodedAddress)); - Assert.assertEquals(dataPage.getBaseOffset() + 64, manager.getOffsetInPage(encodedAddress)); + Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); } @Test