diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 8f7c3b4232691..6ca3e09e3e439 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -252,6 +252,22 @@ private long freeMemory() { return memoryFreed; } + /** + * Force all memory and spill files to be deleted; called by shuffle error-handling code. + */ + public void cleanupAfterError() { + freeMemory(); + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Unable to delete spill file {}", spill.file.getPath()); + } + } + if (spillingEnabled && sorter != null) { + shuffleMemoryManager.release(sorter.getMemoryUsage()); + sorter = null; + } + } + /** * Checks whether there is enough space to insert a new record into the sorter. * @@ -362,11 +378,16 @@ public void insertRecord( * @throws IOException */ public SpillInfo[] closeAndGetSpills() throws IOException { - if (sorter != null) { - writeSpillFile(); - freeMemory(); + try { + if (sorter != null) { + writeSpillFile(); + freeMemory(); + } + return spills.toArray(new SpillInfo[spills.size()]); + } catch (IOException e) { + cleanupAfterError(); + throw e; } - return spills.toArray(new SpillInfo[0]); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index e5a942498ae00..70afea553556c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -17,10 +17,7 @@ package org.apache.spark.shuffle.unsafe; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.util.Iterator; @@ -34,6 +31,7 @@ import com.esotericsoftware.kryo.io.ByteBufferOutputStream; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -152,16 +150,22 @@ void closeAndWriteOutput() throws IOException { serArray = null; serByteBuffer = null; serOutputStream = null; - final long[] partitionLengths = mergeSpills(sorter.closeAndGetSpills()); + final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; + final long[] partitionLengths; + try { + partitionLengths = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } + } + } shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } - private void freeMemory() { - // TODO - } - @VisibleForTesting void insertRecordIntoSorter(Product2 record) throws IOException{ if (sorter == null) { @@ -241,17 +245,10 @@ private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) th } } } finally { - for (int i = 0; i < spills.length; i++) { - if (spillInputStreams[i] != null) { - spillInputStreams[i].close(); - if (!spills[i].file.delete()) { - logger.error("Error while deleting spill file {}", spills[i]); - } - } - } - if (mergedFileOutputStream != null) { - mergedFileOutputStream.close(); + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, false); } + Closeables.close(mergedFileOutputStream, false); } return partitionLengths; } @@ -305,16 +302,9 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th } finally { for (int i = 0; i < spills.length; i++) { assert(spillInputChannelPositions[i] == spills[i].file.length()); - if (spillInputChannels[i] != null) { - spillInputChannels[i].close(); - if (!spills[i].file.delete()) { - logger.error("Error while deleting spill file {}", spills[i]); - } - } - } - if (mergedFileOutputChannel != null) { - mergedFileOutputChannel.close(); + Closeables.close(spillInputChannels[i], false); } + Closeables.close(mergedFileOutputChannel, false); } return partitionLengths; } @@ -326,7 +316,6 @@ public Option stop(boolean success) { return Option.apply(null); } else { stopping = true; - freeMemory(); if (success) { if (mapStatus == null) { throw new IllegalStateException("Cannot call stop(true) without having called write()"); @@ -339,7 +328,11 @@ public Option stop(boolean success) { } } } finally { - freeMemory(); + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + sorter.cleanupAfterError(); + } } } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 09eb537c04367..a1d654c9d121e 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -57,7 +57,7 @@ public class UnsafeShuffleWriterSuite { static final int NUM_PARTITITONS = 4; - final TaskMemoryManager memoryManager = + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); File mergedOutputFile; @@ -82,6 +82,10 @@ public OutputStream apply(OutputStream stream) { @After public void tearDown() { Utils.deleteRecursively(tempDir); + final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (leakedMemory != 0) { + Assert.fail("Test leaked " + leakedMemory + " bytes of managed memory"); + } } @Before @@ -154,7 +158,7 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabl return new UnsafeShuffleWriter( blockManager, shuffleBlockManager, - memoryManager, + taskMemoryManager, shuffleMemoryManager, new UnsafeShuffleHandle(0, 1, shuffleDep), 0, // map id @@ -216,7 +220,7 @@ public void writeWithoutSpilling() throws Exception { } private void testMergingSpills(boolean transferToEnabled) throws IOException { - final UnsafeShuffleWriter writer = createWriter(true); + final UnsafeShuffleWriter writer = createWriter(transferToEnabled); writer.insertRecordIntoSorter(new Tuple2(1, 1)); writer.insertRecordIntoSorter(new Tuple2(2, 2)); writer.insertRecordIntoSorter(new Tuple2(3, 3)); @@ -249,8 +253,17 @@ public void mergeSpillsWithFileStream() throws Exception { testMergingSpills(false); } + @Test + public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { + final UnsafeShuffleWriter writer = createWriter(false); + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.stop(false); + assertSpillFilesWereCleanedUp(); + } + // TODO: actually try to read the shuffle output? - // TODO: add a test that manually triggers spills in order to exercise the merging. -// } }