diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 833a951d7ec08..fc34ad9cff369 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -156,6 +156,7 @@ public boolean hasNext() { @Override public void loadNext() { + // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = sortBuffer[position]; baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 41d0b079835ce..5815c2c487ca3 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -91,6 +91,14 @@ private void writeIntToBuffer(int v, int offset) throws IOException { writeBuffer[offset + 3] = (byte)(v >>> 0); } + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. + * @param keyPrefix a sort key prefix + */ public void write( Object baseObject, long baseOffset, @@ -105,8 +113,8 @@ public void write( writeIntToBuffer(recordLength, 0); writeLongToBuffer(keyPrefix, 4); int dataRemaining = recordLength; - int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; - long recordReadPosition = baseOffset + 4; // skip over record length + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); PlatformDependent.copyMemory( diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 67666e35aaeb9..909500930539c 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -34,14 +34,13 @@ public class UnsafeInMemorySorterSuite { - private static String getStringFromDataPage(Object baseObject, long baseOffset) { - final int strLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); - final byte[] strBytes = new byte[strLength]; + private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { + final byte[] strBytes = new byte[length]; PlatformDependent.copyMemory( baseObject, - baseOffset + 4, + baseOffset, strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + PlatformDependent.BYTE_ARRAY_OFFSET, length); return new String(strBytes); } @@ -116,7 +115,7 @@ public int compare(long prefix1, long prefix2) { // position now points to the start of a record (which holds its length). final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); - final String str = getStringFromDataPage(baseObject, position); + final String str = getStringFromDataPage(baseObject, position + 4, recordLength); final int partitionId = hashPartitioner.getPartition(str); sorter.insertRecord(address, partitionId); position += 4 + recordLength; @@ -127,9 +126,8 @@ public int compare(long prefix1, long prefix2) { Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); - // TODO: the logic for how we manipulate record length offsets here is confusing; clean - // this up and clarify it in comments. - final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset() - 4); + final String str = + getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength()); final long keyPrefix = iter.getKeyPrefix(); assertThat(str, isIn(Arrays.asList(dataToSort))); assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));