diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
index 34c15e6bbcb0e..8c0940d23420b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
@@ -19,9 +19,24 @@
 
 /**
  * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ * <p>
+ * Within the long, the data is laid out as follows:
+ * <pre>
+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * </pre>
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ * <p>
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
  */
 final class PackedRecordPointer {
 
+  static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes
+
   /** Bit mask for the lower 40 bits of a long. */
   private static final long MASK_LONG_LOWER_40_BITS = 0xFFFFFFFFFFL;
 
@@ -55,7 +70,11 @@ public static long packPointer(long recordPointer, int partitionId) {
     return (((long) partitionId) << 40) | compressedAddress;
   }
 
-  public long packedRecordPointer;
+  private long packedRecordPointer;
+
+  public void set(long packedRecordPointer) {
+    this.packedRecordPointer = packedRecordPointer;
+  }
 
   public int getPartitionId() {
     return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
@@ -68,7 +87,4 @@ public long getRecordPointer() {
     return pageNumber | offsetInPage;
   }
 
-  public int getRecordLength() {
-    return -1; // TODO
-  }
 }
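The javadoc above fixes the pointer layout at [24-bit partition | 13-bit page | 27-bit offset]. For readers who want the arithmetic in isolation, here is a minimal, self-contained sketch of that packing scheme. All names are illustrative, and unlike the real class it packs a raw (page, offset) pair directly rather than compressing a TaskMemoryManager-encoded address into the low 40 bits; the round-trip property is the same.

```java
// Illustrative sketch of the [24 | 13 | 27] bit layout described above (not Spark's API).
public final class PackedPointerSketch {
  private static final long MASK_13_BITS = (1L << 13) - 1;
  private static final long MASK_27_BITS = (1L << 27) - 1;

  static long pack(int partitionId, int pageNumber, long offsetInPage) {
    // Partition occupies the top 24 bits, page the next 13, offset the low 27.
    return (((long) partitionId) << 40)
        | ((pageNumber & MASK_13_BITS) << 27)
        | (offsetInPage & MASK_27_BITS);
  }

  static int partitionId(long packed) { return (int) (packed >>> 40); }
  static int pageNumber(long packed) { return (int) ((packed >>> 27) & MASK_13_BITS); }
  static long offsetInPage(long packed) { return packed & MASK_27_BITS; }

  public static void main(String[] args) {
    final long packed = pack(360, 5, 42);
    // Round trip recovers all three fields (run with -ea to enable asserts).
    assert partitionId(packed) == 360;
    assert pageNumber(packed) == 5;
    assert offsetInPage(packed) == 42;
  }
}
```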
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 892a78796335b..6e0d8da410231 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -57,8 +57,9 @@ final class UnsafeShuffleExternalSorter {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
 
-  private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this / don't duplicate
-  private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this
+  @VisibleForTesting
+  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
 
   private final int initialSize;
   private final int numPartitions;
@@ -88,13 +89,13 @@ final class UnsafeShuffleExternalSorter {
   private long freeSpaceInCurrentPage = 0;
 
   public UnsafeShuffleExternalSorter(
-    TaskMemoryManager memoryManager,
-    ShuffleMemoryManager shuffleMemoryManager,
-    BlockManager blockManager,
-    TaskContext taskContext,
-    int initialSize,
-    int numPartitions,
-    SparkConf conf) throws IOException {
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      int initialSize,
+      int numPartitions,
+      SparkConf conf) throws IOException {
     this.memoryManager = memoryManager;
     this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
@@ -140,8 +141,9 @@ private SpillInfo writeSpillFile() throws IOException {
 
     // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
     // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
-    // records in a byte array. This array only needs to be big enough to hold a single record.
-    final byte[] arr = new byte[SER_BUFFER_SIZE];
+    // data through a byte array. This array does not need to be large enough to hold a single
+    // record.
+    final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
 
     // Because this output will be read during shuffle, its compression codec must be controlled by
     // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
@@ -186,16 +188,23 @@ private SpillInfo writeSpillFile() throws IOException {
       }
 
       final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
-      final int recordLength = PlatformDependent.UNSAFE.getInt(
-        memoryManager.getPage(recordPointer), memoryManager.getOffsetInPage(recordPointer));
-      PlatformDependent.copyMemory(
-        memoryManager.getPage(recordPointer),
-        memoryManager.getOffsetInPage(recordPointer) + 4, // skip over record length
-        arr,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        recordLength);
-      assert (writer != null); // To suppress an IntelliJ warning
-      writer.write(arr, 0, recordLength);
+      final Object recordPage = memoryManager.getPage(recordPointer);
+      final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
+      int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
+      long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+      while (dataRemaining > 0) {
+        final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+        PlatformDependent.copyMemory(
+          recordPage,
+          recordReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer);
+        assert (writer != null); // To suppress an IntelliJ warning
+        writer.write(writeBuffer, 0, toTransfer);
+        recordReadPosition += toTransfer;
+        dataRemaining -= toTransfer;
+      }
       // TODO: add a test that detects whether we leave this call out:
       writer.recordWritten();
     }
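The loop above streams each record through the fixed-size write buffer, which is what lets records larger than DISK_WRITE_BUFFER_SIZE be spilled. A standalone sketch of the same chunking pattern, with System.arraycopy standing in for PlatformDependent.copyMemory and a plain OutputStream standing in for DiskBlockObjectWriter:

```java
import java.io.IOException;
import java.io.OutputStream;

// Sketch only: heap arrays instead of managed memory pages; names are illustrative.
final class ChunkedCopy {
  static final int BUFFER_SIZE = 1024 * 1024;

  static void writeRecord(byte[] page, int recordOffset, int recordLength,
                          OutputStream out) throws IOException {
    final byte[] buffer = new byte[BUFFER_SIZE];
    int dataRemaining = recordLength;
    int readPosition = recordOffset;
    while (dataRemaining > 0) {
      // Transfer at most one buffer's worth per iteration.
      final int toTransfer = Math.min(BUFFER_SIZE, dataRemaining);
      System.arraycopy(page, readPosition, buffer, 0, toTransfer);
      out.write(buffer, 0, toTransfer);
      readPosition += toTransfer;
      dataRemaining -= toTransfer;
    }
  }
}
```

In the real code the buffer is allocated once per spill, not per record; it is allocated inline here only to keep the sketch self-contained.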
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
index 862845180584e..a66d74ee44782 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
@@ -38,7 +38,7 @@ public PackedRecordPointer newKey() {
 
   @Override
   public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
-    reuse.packedRecordPointer = data[pos];
+    reuse.set(data[pos]);
     return reuse;
   }
 
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java
index d15da8a7ee126..5acbc6c1c4f2f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java
@@ -95,7 +95,7 @@ public boolean hasNext() {
 
       @Override
       public void loadNext() {
-        packedRecordPointer.packedRecordPointer = sortBuffer[position];
+        packedRecordPointer.set(sortBuffer[position]);
         position++;
       }
     };
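The getKey and loadNext changes above route all mutation through PackedRecordPointer.set(), which the now-private field enforces. The underlying idea is the standard mutable-key-reuse trick in sorters: one holder object is re-pointed at each packed word, so a scan over N records allocates one object instead of N. A sketch with hypothetical names (PointerBox and maxPartition are not Spark API):

```java
// Minimal mutable holder, analogous in spirit to PackedRecordPointer.
final class PointerBox {
  private long packed;
  void set(long packed) { this.packed = packed; }
  int partitionId() { return (int) (packed >>> 40); }
}

final class ReuseDemo {
  static int maxPartition(long[] sortBuffer) {
    final PointerBox box = new PointerBox(); // single reusable holder
    int max = -1;
    for (long word : sortBuffer) {
      box.set(word);                         // re-point instead of re-allocating
      max = Math.max(max, box.partitionId());
    }
    return max;
  }
}
```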
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index f28e63f137bc9..db9f8648a93b4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -54,7 +54,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
 
-  private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
+  @VisibleForTesting
+  static final int MAXIMUM_RECORD_SIZE = 1024 * 1024 * 64; // 64 megabytes
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
   private final BlockManager blockManager;
@@ -108,19 +109,26 @@ public UnsafeShuffleWriter(
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
   }
 
-  public void write(Iterator<Product2<K, V>> records) {
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
     write(JavaConversions.asScalaIterator(records));
   }
 
   @Override
-  public void write(scala.collection.Iterator<Product2<K, V>> records) {
+  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
     try {
       while (records.hasNext()) {
         insertRecordIntoSorter(records.next());
       }
       closeAndWriteOutput();
     } catch (Exception e) {
-      PlatformDependent.throwException(e);
+      // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after
+      // errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary
+      // unchecked exceptions.
+      try {
+        sorter.cleanupAfterError();
+      } finally {
+        throw new IOException("Error during shuffle write", e);
+      }
     }
   }
 
@@ -134,7 +142,7 @@ private void open() throws IOException {
       4096, // Initial size (TODO: tune this!)
       partitioner.numPartitions(),
       sparkConf);
-    serArray = new byte[SER_BUFFER_SIZE];
+    serArray = new byte[MAXIMUM_RECORD_SIZE];
     serByteBuffer = ByteBuffer.wrap(serArray);
     serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
   }
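The new write() body is a cleanup-then-rethrow idiom: cleanup runs inside the catch, and the finally guarantees the original failure is surfaced as an IOException even if the cleanup itself throws (the secondary failure is discarded in favor of the original). A generic sketch of the idiom, with a hypothetical Resource interface standing in for the sorter:

```java
import java.io.IOException;

// Sketch only: Resource, process, and cleanupAfterError are illustrative names.
final class CleanupOnError {
  interface Resource {
    void process() throws Exception;
    void cleanupAfterError();
  }

  static void run(Resource resource) throws IOException {
    try {
      resource.process();
    } catch (Exception e) {
      try {
        resource.cleanupAfterError();
      } finally {
        // Executed even if cleanupAfterError() throws, so the original
        // error always propagates, wrapped as a checked IOException.
        throw new IOException("Error during processing", e);
      }
    }
  }
}
```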
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index e28a2459cdff9..4cc4ef5f1886e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle
 
+import java.io.IOException
+
 import org.apache.spark.scheduler.MapStatus
 
 /**
@@ -24,6 +26,7 @@ import org.apache.spark.scheduler.MapStatus
  */
 private[spark] abstract class ShuffleWriter[K, V] {
   /** Write a sequence of records to this task's output */
+  @throws[IOException]
   def write(records: Iterator[Product2[K, V]]): Unit
 
   /** Close this writer, passing along whether the map completed */
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
index 53554520b22b1..ba1f89d099838 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -34,10 +34,10 @@ public void heap() {
     final MemoryBlock page0 = memoryManager.allocatePage(100);
     final MemoryBlock page1 = memoryManager.allocatePage(100);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42);
-    PackedRecordPointer packedPointerWrapper = new PackedRecordPointer();
-    packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360);
-    Assert.assertEquals(360, packedPointerWrapper.getPartitionId());
-    Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer());
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    Assert.assertEquals(360, packedPointer.getPartitionId());
+    Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer());
     memoryManager.cleanUpAllAllocatedMemory();
   }
 
@@ -48,10 +48,10 @@ public void offHeap() {
     final MemoryBlock page0 = memoryManager.allocatePage(100);
     final MemoryBlock page1 = memoryManager.allocatePage(100);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42);
-    PackedRecordPointer packedPointerWrapper = new PackedRecordPointer();
-    packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360);
-    Assert.assertEquals(360, packedPointerWrapper.getPartitionId());
-    Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer());
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    Assert.assertEquals(360, packedPointer.getPartitionId());
+    Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer());
     memoryManager.cleanUpAllAllocatedMemory();
   }
 }
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index cc8cbc534510b..9002126bb7a4a 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle.unsafe;
 
 import java.io.*;
+import java.nio.ByteBuffer;
 import java.util.*;
 
 import scala.*;
@@ -287,6 +288,42 @@ public void mergeSpillsWithFileStream() throws Exception {
     testMergingSpills(false);
   }
 
+  @Test
+  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    Assert.assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanMaximumRecordSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[UnsafeShuffleWriter.MAXIMUM_RECORD_SIZE * 2];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, bytes));
+    try {
+      // Insert a record and force a spill so that there's something to clean up:
+      writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+      writer.forceSorterToSpill();
+      writer.write(dataToWrite.iterator());
+      Assert.fail("Expected exception to be thrown");
+    } catch (IOException e) {
+      // Pass
+    }
+    assertSpillFilesWereCleanedUp();
+  }
+
   @Test
   public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);