diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
index 5e8c090405098..a1b5266631164 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
@@ -21,6 +21,9 @@
import java.io.File;
+/**
+ * Metadata for a block of data written by {@link UnsafeShuffleSpillWriter}.
+ */
final class SpillInfo {
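+ /** The length, in bytes, of each partition's segment in the spill file, indexed by partition id. */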
final long[] partitionLengths;
final File file;
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java
index fd2c170bd2e41..05cf2e7d0d3cc 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java
@@ -17,30 +17,41 @@
package org.apache.spark.shuffle.unsafe;
-import com.google.common.annotations.VisibleForTesting;
+import java.io.File;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.LinkedList;
+
+import scala.Tuple2;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockObjectWriter;
-import org.apache.spark.storage.TempLocalBlockId;
+import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import scala.Tuple2;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Iterator;
-import java.util.LinkedList;
/**
- * External sorter based on {@link UnsafeShuffleSorter}.
+ * An external sorter that is specialized for sort-based shuffle.
+ *
+ * Incoming records are appended to data pages. When all records have been inserted (or when the
+ * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
+ * their partition ids (using a {@link UnsafeShuffleSorter}). The sorted records are then written
+ * to a single output file (or multiple files, if we've spilled). The format of the output files is
+ * the same as the format of the final output file written by
+ * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
+ * written as a single serialized, compressed stream that can be read with a new decompression and
+ * deserialization stream.
+ *
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
+ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
+ * specialized merge procedure that avoids extra serialization/deserialization.
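+ *
+ * <p>A rough sketch of the expected call sequence (hypothetical caller code; constructor
+ * arguments are elided, and the real call sites live in {@link UnsafeShuffleWriter}):
+ *
+ * <pre>{@code
+ * UnsafeShuffleSpillWriter sorter = new UnsafeShuffleSpillWriter(...);
+ * while (records.hasNext()) {
+ *   // Serialize each record to a byte array, then pass the raw bytes to the sorter:
+ *   sorter.insertRecord(
+ *     serBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, serializedSize, partitionId);
+ * }
+ * // Sort and write any remaining in-memory records, then hand the per-spill metadata to
+ * // UnsafeShuffleWriter, which merges the spill files:
+ * SpillInfo[] spills = sorter.closeAndGetSpills();
+ * }</pre>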
*/
public final class UnsafeShuffleSpillWriter {
@@ -51,23 +62,31 @@ public final class UnsafeShuffleSpillWriter {
private final int initialSize;
private final int numPartitions;
- private UnsafeShuffleSorter sorter;
-
private final TaskMemoryManager memoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
- private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
private final boolean spillingEnabled;
- private final int fileBufferSize;
private ShuffleWriteMetrics writeMetrics;
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSize;
- private MemoryBlock currentPage = null;
- private long currentPagePosition = -1;
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling. In principle we could recycle these pages across spills, although that might be
+ * unnecessary if we instead maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself.
+ */
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+ // All three of these variables are reset after spilling:
+ private UnsafeShuffleSorter sorter;
+ private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+
public UnsafeShuffleSpillWriter(
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
@@ -90,6 +109,10 @@ public UnsafeShuffleSpillWriter(
// TODO: metrics tracking + integration with shuffle write metrics
+ /**
+ * Allocates a new sorter. Called when opening the spill writer for the first time and after
+ * each spill.
+ */
private void openSorter() throws IOException {
this.writeMetrics = new ShuffleWriteMetrics();
// TODO: connect write metrics to task metrics?
@@ -106,22 +129,41 @@ private void openSorter() throws IOException {
this.sorter = new UnsafeShuffleSorter(initialSize);
}
+ /**
+ * Sorts the in-memory records and writes the sorted records to a spill file. Note that this
+ * method does not free the in-memory data structures associated with the sort; the caller is
+ * responsible for doing so (see {@link #spill()} and {@link #closeAndGetSpills()}). New data
+ * structures are not automatically allocated.
+ */
private SpillInfo writeSpillFile() throws IOException {
- final UnsafeShuffleSorter.UnsafeShuffleSorterIterator sortedRecords = sorter.getSortedIterator();
+ // This call performs the actual sort.
+ final UnsafeShuffleSorter.UnsafeShuffleSorterIterator sortedRecords =
+ sorter.getSortedIterator();
- int currentPartition = -1;
+ // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
+ // after SPARK-5581 is fixed.
BlockObjectWriter writer = null;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // records in a byte array. This array only needs to be big enough to hold a single record.
final byte[] arr = new byte[SER_BUFFER_SIZE];
- final Tuple2<TempLocalBlockId, File> spilledFileInfo =
- blockManager.diskBlockManager().createTempLocalBlock();
+ // Because this output will be read during shuffle, its compression codec must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more details.
+ final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempShuffleBlock();
final File file = spilledFileInfo._2();
final BlockId blockId = spilledFileInfo._1();
final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
final SerializerInstance ser = new DummySerializerInstance();
writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics);
+ int currentPartition = -1;
while (sortedRecords.hasNext()) {
sortedRecords.loadNext();
final int partition = sortedRecords.packedRecordPointer.getPartitionId();
@@ -153,7 +195,9 @@ private SpillInfo writeSpillFile() throws IOException {
if (writer != null) {
writer.commitAndClose();
- // TODO: comment and explain why our handling of empty spills, etc.
+ // If `writeSpillFile()` was called from `closeAndGetSpills()` and no records were inserted,
+ // then the spill file might be empty. Note that it might be better to avoid calling
+ // writeSpillFile() in that case.
if (currentPartition != -1) {
spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
spills.add(spillInfo);
@@ -162,24 +206,30 @@ private SpillInfo writeSpillFile() throws IOException {
return spillInfo;
}
- @VisibleForTesting
- public void spill() throws IOException {
- final SpillInfo spillInfo = writeSpillFile();
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ private void spill() throws IOException {
+ final long threadId = Thread.currentThread().getId();
+ logger.info("Thread " + threadId + " spilling sort data of " +
+ org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" +
+ (spills.size() + (spills.size() == 1 ? " time" : " times")) + " so far)");
+ final SpillInfo spillInfo = writeSpillFile();
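+ // Release the memory used by the sort buffer and then free the data pages. Note that only
+ // the page memory freed by freeMemory() counts towards the memoryBytesSpilled metric below.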
final long sorterMemoryUsage = sorter.getMemoryUsage();
sorter = null;
shuffleMemoryManager.release(sorterMemoryUsage);
final long spillSize = freeMemory();
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
taskContext.taskMetrics().incDiskBytesSpilled(spillInfo.file.length());
- final long threadId = Thread.currentThread().getId();
- // TODO: messy; log _before_ spill
- logger.info("Thread " + threadId + " spilling in-memory map of " +
- org.apache.spark.util.Utils.bytesToString(spillSize) + " to disk (" +
- (spills.size() + ((spills.size() > 1) ? " times" : " time")) + " so far)");
+
openSorter();
}
+ private long getMemoryUsage() {
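+ // The sort buffer (the array of packed record pointers) plus all allocated data pages.
+ // Pages are counted at their full size because they are acquired from the
+ // ShuffleMemoryManager one whole page at a time.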
+ return sorter.getMemoryUsage() + ((long) allocatedPages.size() * PAGE_SIZE);
+ }
+
private long freeMemory() {
long memoryFreed = 0;
final Iterator<MemoryBlock> iter = allocatedPages.iterator();
@@ -194,7 +244,15 @@ private long freeMemory() {
return memoryFreed;
}
- private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
+ /**
+ * Checks whether there is enough space to insert a new record into the sorter. If there is
+ * insufficient space, this will either allocate more memory or spill the current sort data
+ * (if spilling is enabled) so that the caller can then insert the record.
+ */
+ private void ensureSpaceInDataPage(int requiredSpace) throws IOException {
+ // TODO: we should re-order the `if` cases in this function so that the most common case (there
+ // is enough space) appears first.
+
// TODO: merge these steps to first calculate total memory requirements for this insert,
// then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
// data page.
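+ // Note that a record is never split across data pages: if a record does not fit in the
+ // space remaining in the current page, we allocate a fresh page and any space left at the
+ // end of the old page goes unused.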
@@ -219,7 +277,7 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
}
if (requiredSpace > PAGE_SIZE) {
// TODO: throw a more specific exception?
- throw new Exception("Required space " + requiredSpace + " is greater than page size (" +
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
PAGE_SIZE + ")");
} else if (requiredSpace > spaceInCurrentPage) {
if (spillingEnabled) {
@@ -230,7 +288,7 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
if (memoryAcquiredAfterSpill != PAGE_SIZE) {
shuffleMemoryManager.release(memoryAcquiredAfterSpill);
- throw new Exception("Can't allocate memory!");
+ throw new IOException("Can't allocate memory!");
}
}
}
@@ -241,11 +299,14 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
}
}
+ /**
+ * Write a record to the shuffle sorter.
+ */
public void insertRecord(
Object recordBaseObject,
long recordBaseOffset,
int lengthInBytes,
- int prefix) throws Exception {
+ int partitionId) throws IOException {
// Need 4 bytes to store the record length.
ensureSpaceInDataPage(lengthInBytes + 4);
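+ // Each record is stored in its data page as [record length: 4 bytes][record bytes]; the
+ // sorter itself only tracks an 8-byte pointer to this location, packed together with the
+ // record's partition id (see PackedRecordPointer).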
@@ -262,12 +323,20 @@ public void insertRecord(
lengthInBytes);
currentPagePosition += lengthInBytes;
- sorter.insertRecord(recordAddress, prefix);
+ sorter.insertRecord(recordAddress, partitionId);
}
+ /**
+ * Close the sorter, causing any buffered data to be sorted and written out to disk.
+ *
+ * @return metadata for the spill files written by this sorter. If no records were ever inserted
+ * into this sorter, then this will return an empty array.
+ * @throws IOException
+ */
public SpillInfo[] closeAndGetSpills() throws IOException {
if (sorter != null) {
writeSpillFile();
+ freeMemory();
}
return spills.toArray(new SpillInfo[0]);
}