From da65427a94b24b2ad0a6ecf79ac1980dd458c684 Mon Sep 17 00:00:00 2001 From: jinxing Date: Fri, 7 Jul 2017 23:29:58 +0800 Subject: [PATCH 1/7] Fix DownloadCallback to work well with RetryingBlockFetcher. --- .../shuffle/OneForOneBlockFetcher.java | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index d46ce2e0e6b78..d02c600a51282 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -24,6 +24,7 @@ import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.Arrays; +import java.util.UUID; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -121,7 +122,7 @@ public void onSuccess(ByteBuffer response) { for (int i = 0; i < streamHandle.numChunks; i++) { if (shuffleFiles != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), - new DownloadCallback(shuffleFiles[i], i)); + new DownloadCallback(i)); } else { client.fetchChunk(streamHandle.streamId, i, chunkCallback); } @@ -151,15 +152,27 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { } } + private static synchronized boolean renameFile(File src, File dest) { + if (dest.exists()) { + if (!dest.delete()) { + return false; + } + } + return src.renameTo(dest); + } + private class DownloadCallback implements StreamCallback { private WritableByteChannel channel = null; private File targetFile = null; + private File tmpFile = null; private int chunkIndex; - DownloadCallback(File targetFile, int chunkIndex) throws IOException { - this.targetFile = targetFile; - this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + DownloadCallback(int chunkIndex) throws IOException { + this.targetFile = shuffleFiles[chunkIndex]; + this.tmpFile = new File(targetFile.getParent(), + targetFile.getName() + "_" + UUID.randomUUID()); + this.channel = Channels.newChannel(new FileOutputStream(tmpFile)); this.chunkIndex = chunkIndex; } @@ -171,14 +184,23 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { @Override public void onComplete(String streamId) throws IOException { channel.close(); + if (!renameFile(tmpFile, targetFile)) { + onFailure(streamId, new Exception("Failed renaming " + tmpFile.getAbsolutePath() + " to " + + targetFile.getAbsolutePath())); + return; + } ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); + tmpFile.delete(); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); } @Override public void onFailure(String streamId, Throwable cause) throws IOException { - channel.close(); + if (channel.isOpen()) { + channel.close(); + } + tmpFile.delete(); // On receipt of a failure, fail every block from chunkIndex onwards. 
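(Aside on the pattern patch 1 introduces: each chunk is written to a uniquely named sibling temp file and only renamed over the target once the stream completes, so a retry started by RetryingBlockFetcher never observes or clobbers a half-written file. A minimal standalone Java sketch of the idea; the `downloadTo` helper is hypothetical, standing in for DownloadCallback's onData/onComplete sequence:

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.UUID;

public class AtomicDownload {
  // Write the data to a uniquely named sibling temp file, then move it into
  // place, so a concurrent or retried fetch never sees a partial target file.
  public static void downloadTo(File targetFile, ByteBuffer data) throws IOException {
    File tmpFile = new File(targetFile.getParent(),
        targetFile.getName() + "_" + UUID.randomUUID());
    try (WritableByteChannel channel = Channels.newChannel(new FileOutputStream(tmpFile))) {
      while (data.hasRemaining()) {
        channel.write(data);
      }
    }
    if (!renameFile(tmpFile, targetFile)) {
      tmpFile.delete();
      throw new IOException("Failed renaming " + tmpFile + " to " + targetFile);
    }
  }

  // Synchronized, as in the patch, so concurrent retries of the same block
  // don't race on the delete-then-rename; deleting a stale destination first
  // tolerates a file left behind by an earlier failed attempt.
  private static synchronized boolean renameFile(File src, File dest) {
    if (dest.exists() && !dest.delete()) {
      return false;
    }
    return src.renameTo(dest);
  }
}
```
)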
String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); failRemainingBlocks(remainingBlockIds, cause);

From dc7acfed9b3e2336ee75333247f434e5db6645d1 Mon Sep 17 00:00:00 2001 From: jinxing Date: Sat, 8 Jul 2017 16:29:41 +0800 Subject: [PATCH 2/7] Create and delete the files in OneForOneBlockFetcher --- .../shuffle/ExternalShuffleClient.java | 15 ++++- .../shuffle/OneForOneBlockFetcher.java | 57 ++++++++++--------- .../spark/network/shuffle/ShuffleClient.java | 20 ++++++- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/OneForOneBlockFetcherSuite.java | 2 +- .../spark/network/BlockTransferService.scala | 4 +- .../netty/NettyBlockTransferService.scala | 10 +++- .../storage/ShuffleBlockFetcherIterator.scala | 23 ++------ .../NettyBlockTransferSecuritySuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 8 +-- 12 files changed, 85 insertions(+), 62 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6ac9302517ee0..3bb1872a8597b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -20,7 +20,9 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.HashSet; import java.util.List; +import java.util.Set; import com.google.common.collect.Lists; import org.slf4j.Logger; @@ -50,6 +52,7 @@ public class ExternalShuffleClient extends ShuffleClient { private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; private final long registrationTimeoutMs; + private final Set blockfetchers; protected TransportClientFactory clientFactory; protected String appId; @@ -67,6 +70,7 @@ public ExternalShuffleClient( this.secretKeyHolder = secretKeyHolder; this.authEnabled = authEnabled; this.registrationTimeoutMs = registrationTimeoutMs; + this.blockfetchers = new HashSet<>(); } protected void checkInit() { @@ -91,15 +95,17 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - File[] shuffleFiles) { + boolean toDisk) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf, - shuffleFiles).start(); + OneForOneBlockFetcher blockFetcher = new OneForOneBlockFetcher(client, appId, execId, + blockIds1, listener1, conf, toDisk, tmpFileCreater); + blockfetchers.add(blockFetcher); + blockFetcher.start(); }; int maxRetries = conf.maxIORetries(); @@ -142,5 +148,8 @@ public void registerWithShuffleServer( @Override public void close() { clientFactory.close(); + for (OneForOneBlockFetcher blockFetcher : blockfetchers) { + blockFetcher.cleanup(); + } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index d02c600a51282..cd212520234c0 100644 ---
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -24,7 +24,8 @@ import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.Arrays; -import java.util.UUID; +import java.util.HashSet; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -59,10 +60,25 @@ public class OneForOneBlockFetcher { private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; private TransportConf transportConf = null; - private File[] shuffleFiles = null; + private Boolean toDisk; + private ShuffleClient.TmpFileCreater tmpFileCreater; + + // A set to store the files used for shuffling remote huge blocks. Files in this set will be + // deleted when cleanup. This is a layer of defensiveness against disk file leaks. + private Set shuffleFiles; private StreamHandle streamHandle = null; + public OneForOneBlockFetcher( + TransportClient client, + String appId, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TransportConf transportConf) { + this(client, appId, execId, blockIds, listener, transportConf, false, null); + } + public OneForOneBlockFetcher( TransportClient client, String appId, @@ -70,18 +86,17 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - File[] shuffleFiles) { + Boolean toDisk, + ShuffleClient.TmpFileCreater tmpFileCreater) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; - if (shuffleFiles != null) { - this.shuffleFiles = shuffleFiles; - assert this.shuffleFiles.length == blockIds.length: - "Number of shuffle files should equal to blocks"; - } + this.toDisk = toDisk; + this.tmpFileCreater = tmpFileCreater; + this.shuffleFiles = new HashSet<>(); } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -120,7 +135,7 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. 
for (int i = 0; i < streamHandle.numChunks; i++) { - if (shuffleFiles != null) { + if (toDisk) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { @@ -152,27 +167,22 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { } } - private static synchronized boolean renameFile(File src, File dest) { - if (dest.exists()) { - if (!dest.delete()) { - return false; - } + public void cleanup() { + for (File file: shuffleFiles) { + file.delete(); } - return src.renameTo(dest); } private class DownloadCallback implements StreamCallback { private WritableByteChannel channel = null; private File targetFile = null; - private File tmpFile = null; private int chunkIndex; DownloadCallback(int chunkIndex) throws IOException { - this.targetFile = shuffleFiles[chunkIndex]; - this.tmpFile = new File(targetFile.getParent(), - targetFile.getName() + "_" + UUID.randomUUID()); - this.channel = Channels.newChannel(new FileOutputStream(tmpFile)); + this.targetFile = tmpFileCreater.createTempBlock(); + shuffleFiles.add(targetFile); + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); this.chunkIndex = chunkIndex; } @@ -184,14 +194,8 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { @Override public void onComplete(String streamId) throws IOException { channel.close(); - if (!renameFile(tmpFile, targetFile)) { - onFailure(streamId, new Exception("Failed renaming " + tmpFile.getAbsolutePath() + " to " + - targetFile.getAbsolutePath())); - return; - } ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); - tmpFile.delete(); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); } @@ -200,7 +204,6 @@ public void onFailure(String streamId, Throwable cause) throws IOException { if (channel.isOpen()) { channel.close(); } - tmpFile.delete(); // On receipt of a failure, fail every block from chunkIndex onwards. String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); failRemainingBlocks(remainingBlockIds, cause); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 978ff5a2a8699..b348d5bad2a71 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -42,5 +42,23 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - File[] shuffleFiles); + boolean toDisk); + + /** + * Used to create tmp files to shuffle remote blocks to disk. + */ + protected TmpFileCreater tmpFileCreater; + + public void setTmpFileCreaterWhenNull(TmpFileCreater tmpFileCreater) { + if (this.tmpFileCreater == null) { + this.tmpFileCreater = tmpFileCreater; + } + } + + /** + * An interface to provide approach to create tmp file. 
+ */ + public interface TmpFileCreater { + File createTempBlock(); + } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 8110f1e004c73..02e6eb3a4467e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -204,7 +204,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" }; OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null); + new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf); fetcher.start(); blockFetchLatch.await(); checkSecurityException(exception.get()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a6a1b8d0ac3f1..7bc9517d667f7 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } } } - }, null); + }, false); if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 61d82214e7d30..dc947a619bf02 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -131,7 +131,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap { diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 6860214c7fe39..8109b14a8e7da 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit + toDisk: Boolean): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. 
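(The TmpFileCreater interface above deliberately says nothing about where the files live; that decision stays with the caller. As an illustration only — `LocalTmpFileCreater` and `scratchDir` are hypothetical, not part of the patch — an implementation backed by a caller-owned scratch directory could look like:

```java
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import org.apache.spark.network.shuffle.ShuffleClient;

// Hypothetical TmpFileCreater handing out uniquely named files under a
// caller-owned scratch directory. createTempBlock declares no checked
// exceptions, so any IOException is wrapped.
public class LocalTmpFileCreater implements ShuffleClient.TmpFileCreater {
  private final File scratchDir;

  public LocalTmpFileCreater(File scratchDir) {
    this.scratchDir = scratchDir;
  }

  @Override
  public File createTempBlock() {
    try {
      return File.createTempFile("shuffle-block-", ".tmp", scratchDir);
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }
}
```

A caller would install it once via the setter shown earlier, e.g. `shuffleClient.setTmpFileCreaterWhenNull(new LocalTmpFileCreater(scratchDir))`; in Spark itself the creator is backed by `DiskBlockManager.createTempLocalBlock()`, as the ShuffleBlockFetcherIterator hunk later in this patch shows.)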
@@ -101,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo ret.flip() result.success(new NioManagedBuffer(ret)) } - }, shuffleFiles = null) + }, toDisk = false) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b13a9c681e543..a674238902726 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -53,6 +53,7 @@ private[spark] class NettyBlockTransferService( private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) + private val blockFetchers = collection.mutable.HashSet[OneForOneBlockFetcher]() private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ @@ -90,14 +91,16 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit = { + toDisk: Boolean): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener, - transportConf, shuffleFiles).start() + val blockFetcher = new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, + listener, transportConf, toDisk, tmpFileCreater) + blockFetchers += blockFetcher + blockFetcher.start() } } @@ -158,5 +161,6 @@ private[spark] class NettyBlockTransferService( if (clientFactory != null) { clientFactory.close() } + blockFetchers.foreach(_.cleanup()) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a10f1feadd0af..68945e9a20146 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -131,12 +131,6 @@ final class ShuffleBlockFetcherIterator( @GuardedBy("this") private[this] var isZombie = false - /** - * A set to store the files used for shuffling remote huge blocks. Files in this set will be - * deleted when cleanup. This is a layer of defensiveness against disk file leaks. - */ - val shuffleFilesSet = mutable.HashSet[File]() - initialize() // Decrements the buffer reference count. @@ -174,11 +168,6 @@ final class ShuffleBlockFetcherIterator( case _ => } } - shuffleFilesSet.foreach { file => - if (!file.delete()) { - logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()); - } - } } private[this] def sendRequest(req: FetchRequest) { @@ -221,15 +210,15 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. 
if (req.size > maxReqSizeShuffleToMem) { - val shuffleFiles = blockIds.map { _ => - blockManager.diskBlockManager.createTempLocalBlock()._2 - }.toArray - shuffleFilesSet ++= shuffleFiles + shuffleClient.setTmpFileCreaterWhenNull(new ShuffleClient.TmpFileCreater { + override def createTempBlock(): File = + blockManager.diskBlockManager.createTempLocalBlock()._2 + }) shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, shuffleFiles) + blockFetchingListener, true) } else { shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, null) + blockFetchingListener, false) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 474e30144f629..b6498b34443df 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }, null) + }, false) ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 086adccea954c..7cd15928f1041 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1382,7 +1382,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit = { + toDisk: Boolean): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 559b3faab8fd2..ed44b0610062d 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -432,12 +432,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var shuffleFiles: Array[File] = null + var shuffleFiles: Boolean = false when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + shuffleFiles = invocation.getArguments()(5).asInstanceOf[Boolean] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -466,13 +466,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which 
is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(shuffleFiles === null) + assert(!shuffleFiles) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(shuffleFiles != null) + assert(shuffleFiles) } } From 6307a6279ed7bf3680c7ae3836128bdfdf228697 Mon Sep 17 00:00:00 2001 From: jinxing Date: Sat, 8 Jul 2017 23:30:12 +0800 Subject: [PATCH 3/7] Pass in tmpFileCreater as a param of ShuffleClient.fetchBlocks --- .../shuffle/ExternalShuffleClient.java | 3 ++- .../shuffle/OneForOneBlockFetcher.java | 4 +--- .../spark/network/shuffle/ShuffleClient.java | 14 ++--------- .../ExternalShuffleIntegrationSuite.java | 2 +- .../spark/network/BlockTransferService.scala | 5 ++-- .../netty/NettyBlockTransferService.scala | 5 ++-- .../storage/ShuffleBlockFetcherIterator.scala | 8 +++---- .../NettyBlockTransferSecuritySuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 4 +++- .../ShuffleBlockFetcherIteratorSuite.scala | 24 +++++++++---------- 10 files changed, 32 insertions(+), 39 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 3bb1872a8597b..c649b12e15b0a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -95,7 +95,8 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - boolean toDisk) { + boolean toDisk, + TmpFileCreater tmpFileCreater) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index cd212520234c0..06b35a4281f6c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -201,9 +201,7 @@ public void onComplete(String streamId) throws IOException { @Override public void onFailure(String streamId, Throwable cause) throws IOException { - if (channel.isOpen()) { - channel.close(); - } + channel.close(); // On receipt of a failure, fail every block from chunkIndex onwards. String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); failRemainingBlocks(remainingBlockIds, cause); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index b348d5bad2a71..43ddc437686bf 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -42,18 +42,8 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - boolean toDisk); - - /** - * Used to create tmp files to shuffle remote blocks to disk. 
- */ - protected TmpFileCreater tmpFileCreater; - - public void setTmpFileCreaterWhenNull(TmpFileCreater tmpFileCreater) { - if (this.tmpFileCreater == null) { - this.tmpFileCreater = tmpFileCreater; - } - } + boolean toDisk, + TmpFileCreater tmpFileCreater); /** * An interface to provide approach to create tmp file. diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 7bc9517d667f7..e783a529e145a 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } } } - }, false); + }, false, null); if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 8109b14a8e7da..fb3e38efa58dd 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -68,7 +68,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - toDisk: Boolean): Unit + toDisk: Boolean, + tmpFileCreater: ShuffleClient.TmpFileCreater): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. 
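(Patch 3's point is visible in the signatures above: the creator now travels with each fetchBlocks call instead of being installed once through the removed `setTmpFileCreaterWhenNull`, which kept mutable state on a shared ShuffleClient. A hedged sketch of a call site under the new signature — `FetchToDiskExample` and `scratchDir` are illustrative:

```java
import java.io.File;
import java.util.UUID;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ShuffleClient;

public class FetchToDiskExample {
  // Hypothetical call site: the creator is supplied per call, so each fetch
  // decides for itself how temp files are named and placed.
  public static void fetchLargeBlocks(
      ShuffleClient client,
      String host, int port, String execId,
      String[] blockIds,
      BlockFetchingListener listener,
      File scratchDir) {
    // The file is only named here; DownloadCallback's FileOutputStream
    // actually creates it when the stream starts.
    ShuffleClient.TmpFileCreater creator =
        () -> new File(scratchDir, "block-" + UUID.randomUUID());
    client.fetchBlocks(host, port, execId, blockIds, listener, true, creator);
  }
}
```
)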
@@ -101,7 +102,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo ret.flip() result.success(new NioManagedBuffer(ret)) } - }, toDisk = false) + }, toDisk = false, tmpFileCreater = null) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index a674238902726..31f2a8d504575 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -30,7 +30,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, ShuffleClient} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -91,7 +91,8 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - toDisk: Boolean): Unit = { + toDisk: Boolean, + tmpFileCreater: ShuffleClient.TmpFileCreater): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 68945e9a20146..f8fa3d9bb861f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -210,15 +210,15 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. 
if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.setTmpFileCreaterWhenNull(new ShuffleClient.TmpFileCreater { + val tmpFileCreater = new ShuffleClient.TmpFileCreater { override def createTempBlock(): File = blockManager.diskBlockManager.createTempLocalBlock()._2 - }) + } shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, true) + blockFetchingListener, true, tmpFileCreater) } else { shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, false) + blockFetchingListener, false, null) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index b6498b34443df..64476c007cfc3 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }, false) + }, false, null) ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 7cd15928f1041..11ad24d03f682 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -46,6 +46,7 @@ import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.ShuffleClient import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus @@ -1382,7 +1383,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - toDisk: Boolean): Unit = { + toDisk: Boolean, + tmpFileCreater: ShuffleClient.TmpFileCreater): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index ed44b0610062d..3f94d0c5cfd1d 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -46,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] @@ -138,7 +138,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -157,7 +157,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -224,7 +224,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -289,7 +289,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -328,7 +328,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -370,7 +370,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ 
-432,12 +432,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var shuffleFiles: Boolean = false - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + var toDisk: Boolean = false + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - shuffleFiles = invocation.getArguments()(5).asInstanceOf[Boolean] + toDisk = invocation.getArguments()(5).asInstanceOf[Boolean] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -466,13 +466,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(!shuffleFiles) + assert(!toDisk) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(shuffleFiles) + assert(toDisk) } } From d489ba25bc1b92d2dbc6100ad789e5b512fb13b1 Mon Sep 17 00:00:00 2001 From: jinxing Date: Sun, 9 Jul 2017 16:20:56 +0800 Subject: [PATCH 4/7] delete succeeded files in ShuffleBlockFetcherIterator and delete others in OneForOneBlockFetcher --- .../shuffle/ExternalShuffleClient.java | 18 ++++------- .../shuffle/OneForOneBlockFetcher.java | 30 ++++++++----------- .../spark/network/shuffle/ShuffleClient.java | 11 ++----- .../ExternalShuffleIntegrationSuite.java | 2 +- .../spark/network/BlockTransferService.scala | 6 ++-- .../netty/NettyBlockTransferService.scala | 14 ++++----- .../storage/ShuffleBlockFetcherIterator.scala | 26 ++++++++++++---- .../NettyBlockTransferSecuritySuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 4 ++- .../ShuffleBlockFetcherIteratorSuite.scala | 16 +++++----- 10 files changed, 64 insertions(+), 65 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index c649b12e15b0a..859a7eaa78a07 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -20,9 +20,8 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.HashSet; import java.util.List; -import java.util.Set; +import java.util.function.Supplier; import com.google.common.collect.Lists; import org.slf4j.Logger; @@ -52,7 +51,6 @@ public class ExternalShuffleClient extends ShuffleClient { private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; private final long registrationTimeoutMs; - private final Set blockfetchers; protected TransportClientFactory clientFactory; protected String appId; @@ -70,7 +68,6 @@ public ExternalShuffleClient( this.secretKeyHolder = secretKeyHolder; this.authEnabled = authEnabled; this.registrationTimeoutMs 
= registrationTimeoutMs; - this.blockfetchers = new HashSet<>(); } protected void checkInit() { @@ -96,17 +93,17 @@ public void fetchBlocks( String[] blockIds, BlockFetchingListener listener, boolean toDisk, - TmpFileCreater tmpFileCreater) { + Supplier tmpFileCreater, + Supplier shuffleBlockFetcherIteratorIsZombie) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); - OneForOneBlockFetcher blockFetcher = new OneForOneBlockFetcher(client, appId, execId, - blockIds1, listener1, conf, toDisk, tmpFileCreater); - blockfetchers.add(blockFetcher); - blockFetcher.start(); + new OneForOneBlockFetcher(client, appId, execId, + blockIds1, listener1, conf, toDisk, tmpFileCreater, + shuffleBlockFetcherIteratorIsZombie).start(); }; int maxRetries = conf.maxIORetries(); @@ -149,8 +146,5 @@ public void registerWithShuffleServer( @Override public void close() { clientFactory.close(); - for (OneForOneBlockFetcher blockFetcher : blockfetchers) { - blockFetcher.cleanup(); - } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 06b35a4281f6c..8c63d1e83339a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -24,8 +24,7 @@ import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; +import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,11 +60,8 @@ public class OneForOneBlockFetcher { private final ChunkReceivedCallback chunkCallback; private TransportConf transportConf = null; private Boolean toDisk; - private ShuffleClient.TmpFileCreater tmpFileCreater; - - // A set to store the files used for shuffling remote huge blocks. Files in this set will be - // deleted when cleanup. This is a layer of defensiveness against disk file leaks. - private Set shuffleFiles; + private Supplier tmpFileCreater; + private Supplier shuffleBlockFetcherIteratorIsZombie; private StreamHandle streamHandle = null; @@ -76,7 +72,7 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf) { - this(client, appId, execId, blockIds, listener, transportConf, false, null); + this(client, appId, execId, blockIds, listener, transportConf, false, null, null); } public OneForOneBlockFetcher( @@ -87,7 +83,8 @@ public OneForOneBlockFetcher( BlockFetchingListener listener, TransportConf transportConf, Boolean toDisk, - ShuffleClient.TmpFileCreater tmpFileCreater) { + Supplier tmpFileCreater, + Supplier shuffleBlockFetcherIteratorIsZombie) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; @@ -96,7 +93,7 @@ public OneForOneBlockFetcher( this.transportConf = transportConf; this.toDisk = toDisk; this.tmpFileCreater = tmpFileCreater; - this.shuffleFiles = new HashSet<>(); + this.shuffleBlockFetcherIteratorIsZombie = shuffleBlockFetcherIteratorIsZombie; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. 
*/ @@ -167,12 +164,6 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { } } - public void cleanup() { - for (File file: shuffleFiles) { - file.delete(); - } - } - private class DownloadCallback implements StreamCallback { private WritableByteChannel channel = null; @@ -180,8 +171,7 @@ private class DownloadCallback implements StreamCallback { private int chunkIndex; DownloadCallback(int chunkIndex) throws IOException { - this.targetFile = tmpFileCreater.createTempBlock(); - shuffleFiles.add(targetFile); + this.targetFile = tmpFileCreater.get(); this.channel = Channels.newChannel(new FileOutputStream(targetFile)); this.chunkIndex = chunkIndex; } @@ -197,6 +187,9 @@ public void onComplete(String streamId) throws IOException { ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + if (shuffleBlockFetcherIteratorIsZombie.get()) { + targetFile.delete(); + } } @Override @@ -205,6 +198,7 @@ public void onFailure(String streamId, Throwable cause) throws IOException { // On receipt of a failure, fail every block from chunkIndex onwards. String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); failRemainingBlocks(remainingBlockIds, cause); + targetFile.delete(); } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 43ddc437686bf..c286595e68a49 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -19,6 +19,7 @@ import java.io.Closeable; import java.io.File; +import java.util.function.Supplier; /** Provides an interface for reading shuffle files, either from an Executor or external service. */ public abstract class ShuffleClient implements Closeable { @@ -43,12 +44,6 @@ public abstract void fetchBlocks( String[] blockIds, BlockFetchingListener listener, boolean toDisk, - TmpFileCreater tmpFileCreater); - - /** - * An interface to provide approach to create tmp file. 
- */ - public interface TmpFileCreater { - File createTempBlock(); - } + Supplier tmpFileCreater, + Supplier shuffleBlockFetcherIteratorIsZombie); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index e783a529e145a..7a16c1fb68c08 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } } } - }, false, null); + }, false, null, null); if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index fb3e38efa58dd..7ecbc75f86863 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -19,6 +19,7 @@ package org.apache.spark.network import java.io.{Closeable, File} import java.nio.ByteBuffer +import java.util.function.Supplier import scala.concurrent.{Future, Promise} import scala.concurrent.duration.Duration @@ -69,7 +70,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo blockIds: Array[String], listener: BlockFetchingListener, toDisk: Boolean, - tmpFileCreater: ShuffleClient.TmpFileCreater): Unit + tmpFileCreater: Supplier[File], + shuffleBlockFetcherIteratorIsZombie: Supplier[java.lang.Boolean]): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. 
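(Patch 4 drops the bespoke TmpFileCreater in favor of plain `java.util.function.Supplier`, and adds a second supplier through which DownloadCallback can ask whether the owning ShuffleBlockFetcherIterator has already become a zombie, deleting a late-arriving file on the spot. A minimal sketch of that two-supplier contract — `SupplierContractExample` is illustrative, not the patch's class:

```java
import java.io.File;
import java.util.function.Supplier;

// Hedged sketch of the patch-4 contract: the fetcher pulls temp files from
// one Supplier and consults another to learn whether the caller has shut
// down, in which case files that arrive too late are reclaimed immediately.
public class SupplierContractExample {
  private final Supplier<File> tmpFileCreater;
  private final Supplier<Boolean> isZombie;

  public SupplierContractExample(Supplier<File> tmpFileCreater,
                                 Supplier<Boolean> isZombie) {
    this.tmpFileCreater = tmpFileCreater;
    this.isZombie = isZombie;
  }

  File onChunkStart() {
    return tmpFileCreater.get(); // one fresh file per streamed chunk
  }

  void onChunkComplete(File targetFile) {
    // Mirrors DownloadCallback.onComplete: if the iterator died while the
    // stream was in flight, nobody will consume the buffer, so reclaim the
    // file here rather than leak it.
    if (isZombie.get()) {
      targetFile.delete();
    }
  }
}
```
)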
@@ -102,7 +104,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo ret.flip() result.success(new NioManagedBuffer(ret)) } - }, toDisk = false, tmpFileCreater = null) + }, toDisk = false, tmpFileCreater = null, shuffleBlockFetcherIteratorIsZombie = null) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 31f2a8d504575..03c4d211cf047 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -19,6 +19,7 @@ package org.apache.spark.network.netty import java.io.File import java.nio.ByteBuffer +import java.util.function.Supplier import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} @@ -30,7 +31,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, ShuffleClient} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -53,7 +54,6 @@ private[spark] class NettyBlockTransferService( private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) - private val blockFetchers = collection.mutable.HashSet[OneForOneBlockFetcher]() private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ @@ -92,16 +92,15 @@ private[spark] class NettyBlockTransferService( blockIds: Array[String], listener: BlockFetchingListener, toDisk: Boolean, - tmpFileCreater: ShuffleClient.TmpFileCreater): Unit = { + tmpFileCreater: Supplier[File], + shuffleBlockFetcherIteratorIsZombie: Supplier[java.lang.Boolean]): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - val blockFetcher = new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, - listener, transportConf, toDisk, tmpFileCreater) - blockFetchers += blockFetcher - blockFetcher.start() + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, + transportConf, toDisk, tmpFileCreater, shuffleBlockFetcherIteratorIsZombie).start() } } @@ -162,6 +161,5 @@ private[spark] class NettyBlockTransferService( if (clientFactory != null) { clientFactory.close() } - blockFetchers.foreach(_.cleanup()) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f8fa3d9bb861f..d7307a77d477b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue +import java.util.function.Supplier import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -131,6 +132,12 @@ final class ShuffleBlockFetcherIterator( @GuardedBy("this") private[this] var isZombie = false + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + val shuffleFilesSet = mutable.HashSet[File]() + initialize() // Decrements the buffer reference count. @@ -160,6 +167,7 @@ final class ShuffleBlockFetcherIterator( if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleFilesSet += buf.asInstanceOf[FileSegmentManagedBuffer].getFile shuffleMetrics.incRemoteBytesReadToDisk(buf.size) } shuffleMetrics.incRemoteBlocksFetched(1) @@ -168,6 +176,11 @@ final class ShuffleBlockFetcherIterator( case _ => } } + shuffleFilesSet.foreach { file => + if (!file.delete()) { + logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()) + } + } } private[this] def sendRequest(req: FetchRequest) { @@ -210,15 +223,15 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. if (req.size > maxReqSizeShuffleToMem) { - val tmpFileCreater = new ShuffleClient.TmpFileCreater { - override def createTempBlock(): File = - blockManager.diskBlockManager.createTempLocalBlock()._2 - } shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, true, tmpFileCreater) + blockFetchingListener, true, new Supplier[File] { + override def get(): File = blockManager.diskBlockManager.createTempLocalBlock()._2 + }, new Supplier[java.lang.Boolean] { + override def get(): java.lang.Boolean = isZombie + }) } else { shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, false, null) + blockFetchingListener, false, null, null) } } @@ -356,6 +369,7 @@ final class ShuffleBlockFetcherIterator( if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleFilesSet += buf.asInstanceOf[FileSegmentManagedBuffer].getFile shuffleMetrics.incRemoteBytesReadToDisk(buf.size) } shuffleMetrics.incRemoteBlocksFetched(1) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 64476c007cfc3..93d7826fa6dd1 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }, false, null) + }, false, null, null) ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get diff --git 
a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 11ad24d03f682..a14c2ad04a1de 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io.File import java.nio.ByteBuffer +import java.util.function.Supplier import scala.collection.JavaConverters._ import scala.collection.mutable @@ -1384,7 +1385,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE blockIds: Array[String], listener: BlockFetchingListener, toDisk: Boolean, - tmpFileCreater: ShuffleClient.TmpFileCreater): Unit = { + tmpFileCreater: Supplier[File], + shuffleBlockFetcherIteratorIsZombie: Supplier[java.lang.Boolean]): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 3f94d0c5cfd1d..c959ce00e24fd 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -46,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] @@ -138,7 +138,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -157,7 +157,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -224,7 +224,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = 
          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -289,7 +289,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -328,7 +328,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val (id1, _) = iterator.next()
     assert(id1 === ShuffleBlockId(0, 0, 0))

-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -370,7 +370,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(corruptBuffer.createInputStream()).thenReturn(corruptStream)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -433,7 +433,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
     val transfer = mock(classOf[BlockTransferService])
     var toDisk: Boolean = false
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]

From ef9f994ab51c43583b0160e5451c256498fd6fec Mon Sep 17 00:00:00 2001
From: jinxing
Date: Sun, 9 Jul 2017 22:30:23 +0800
Subject: [PATCH 5/7] Remove toDisk and rename shuffleBlockFetcherIteratorIsZombie to canCallerSideDeleteFile

---
 .../shuffle/ExternalShuffleClient.java        |  7 +++---
 .../shuffle/OneForOneBlockFetcher.java        | 17 ++++++-------
 .../spark/network/shuffle/ShuffleClient.java  |  3 +--
 .../ExternalShuffleIntegrationSuite.java      |  2 +-
 .../spark/network/BlockTransferService.scala  |  5 ++--
 .../netty/NettyBlockTransferService.scala     |  5 ++--
 .../storage/ShuffleBlockFetcherIterator.scala |  4 +--
 .../NettyBlockTransferSecuritySuite.scala     |  2 +-
 .../spark/storage/BlockManagerSuite.scala     |  3 +--
 .../ShuffleBlockFetcherIteratorSuite.scala    | 25 ++++++++++---------
 10 files changed, 34 insertions(+), 39 deletions(-)
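Patch 5 folds the boolean into the callback itself: once tmpFileCreater is nullable, toDisk is redundant, because "stream to disk" is exactly "tmpFileCreater != null". A sketch of that null-sentinel convention (FetchModeDemo is hypothetical, not part of the patch):

    import java.io.File;
    import java.util.function.Supplier;

    class FetchModeDemo {
      // Mirrors the convention the patch adopts: the nullable callback doubles
      // as the mode flag, so the two can never disagree.
      static String mode(Supplier<File> tmpFileCreater) {
        return tmpFileCreater == null ? "buffer in memory" : "stream to temp file";
      }

      public static void main(String[] args) {
        System.out.println(mode(null));                      // buffer in memory
        System.out.println(mode(() -> new File("/tmp/x")));  // stream to temp file
      }
    }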
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 859a7eaa78a07..dbf183ee6ed05 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -92,9 +92,8 @@ public void fetchBlocks(
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      boolean toDisk,
       Supplier<File> tmpFileCreater,
-      Supplier<Boolean> shuffleBlockFetcherIteratorIsZombie) {
+      Supplier<Boolean> canCallerSideDeleteFile) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
@@ -102,8 +101,8 @@ public void fetchBlocks(
         (blockIds1, listener1) -> {
           TransportClient client = clientFactory.createClient(host, port);
           new OneForOneBlockFetcher(client, appId, execId,
-            blockIds1, listener1, conf, toDisk, tmpFileCreater,
-            shuffleBlockFetcherIteratorIsZombie).start();
+            blockIds1, listener1, conf, tmpFileCreater,
+            canCallerSideDeleteFile).start();
         };

       int maxRetries = conf.maxIORetries();
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 8c63d1e83339a..641309c20c05a 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -59,9 +59,8 @@ public class OneForOneBlockFetcher {
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
   private TransportConf transportConf = null;
-  private Boolean toDisk;
   private Supplier<File> tmpFileCreater;
-  private Supplier<Boolean> shuffleBlockFetcherIteratorIsZombie;
+  private Supplier<Boolean> canCallerSideDeleteFile;

   private StreamHandle streamHandle = null;

@@ -72,7 +71,7 @@ public OneForOneBlockFetcher(
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf) {
-    this(client, appId, execId, blockIds, listener, transportConf, false, null, null);
+    this(client, appId, execId, blockIds, listener, transportConf, null, null);
   }

@@ -82,18 +81,18 @@ public OneForOneBlockFetcher(
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf,
-      Boolean toDisk,
       Supplier<File> tmpFileCreater,
-      Supplier<Boolean> shuffleBlockFetcherIteratorIsZombie) {
+      Supplier<Boolean> canCallerSideDeleteFile) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
     this.transportConf = transportConf;
-    this.toDisk = toDisk;
     this.tmpFileCreater = tmpFileCreater;
-    this.shuffleBlockFetcherIteratorIsZombie = shuffleBlockFetcherIteratorIsZombie;
+    this.canCallerSideDeleteFile = canCallerSideDeleteFile;
+    assert (tmpFileCreater == null && canCallerSideDeleteFile == null ||
+      tmpFileCreater != null && canCallerSideDeleteFile != null);
   }

   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -132,7 +131,7 @@ public void onSuccess(ByteBuffer response) {
         // Immediately request all chunks -- we expect that the total size of the request is
         // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
         for (int i = 0; i < streamHandle.numChunks; i++) {
-          if (toDisk) {
+          if (tmpFileCreater != null) {
             client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
               new DownloadCallback(i));
           } else {
             client.fetchChunk(streamHandle.streamId, i, chunkCallback);
           }
@@ -187,7 +186,7 @@ public void onComplete(String streamId) throws IOException {
       ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
         targetFile.length());
       listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
-      if (shuffleBlockFetcherIteratorIsZombie.get()) {
+      if (canCallerSideDeleteFile.get()) {
         targetFile.delete();
       }
     }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index c286595e68a49..9f33294311d49 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -43,7 +43,6 @@ public abstract void fetchBlocks(
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      boolean toDisk,
       Supplier<File> tmpFileCreater,
-      Supplier<Boolean> shuffleBlockFetcherIteratorIsZombie);
+      Supplier<Boolean> canCallerSideDeleteFile);
 }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 7a16c1fb68c08..2aaa04def24d4 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) {
           }
         }
       }
-    }, false, null, null);
+    }, null, null);

     if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
       fail("Timeout getting response from the server");
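Patch 5's OneForOneBlockFetcher constructor asserts a both-or-neither rule on the two callbacks. An equivalent, slightly more direct formulation of the same invariant (InvariantDemo is a hypothetical illustration, not part of the patch):

    class InvariantDemo {
      static void check(Object tmpFileCreater, Object canCallerSideDeleteFile) {
        // XOR on nullness: exactly one being null is the only invalid state.
        if ((tmpFileCreater == null) != (canCallerSideDeleteFile == null)) {
          throw new IllegalArgumentException(
            "tmpFileCreater and canCallerSideDeleteFile must be both set or both null");
        }
      }

      public static void main(String[] args) {
        check(null, null);    // valid: in-memory fetch
        check("c", "d");      // valid: fetch to disk
        try {
          check("c", null);   // invalid
        } catch (IllegalArgumentException e) {
          System.out.println(e.getMessage());
        }
      }
    }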
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 7ecbc75f86863..f4cf5fed6c367 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -69,9 +69,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      toDisk: Boolean,
       tmpFileCreater: Supplier[File],
-      shuffleBlockFetcherIteratorIsZombie: Supplier[java.lang.Boolean]): Unit
+      canCallerSideDeleteFile: Supplier[java.lang.Boolean]): Unit

   /**
    * Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -104,7 +103,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
           ret.flip()
           result.success(new NioManagedBuffer(ret))
         }
-      }, toDisk = false, tmpFileCreater = null, shuffleBlockFetcherIteratorIsZombie = null)
+      }, tmpFileCreater = null, canCallerSideDeleteFile = null)
     ThreadUtils.awaitResult(result.future, Duration.Inf)
   }

diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 03c4d211cf047..16fa3f7db799a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -91,16 +91,15 @@ private[spark] class NettyBlockTransferService(
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      toDisk: Boolean,
       tmpFileCreater: Supplier[File],
-      shuffleBlockFetcherIteratorIsZombie: Supplier[java.lang.Boolean]): Unit = {
+      canCallerSideDeleteFile: Supplier[java.lang.Boolean]): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
         override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
           val client = clientFactory.createClient(host, port)
           new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
-            transportConf, toDisk, tmpFileCreater, shuffleBlockFetcherIteratorIsZombie).start()
+            transportConf, tmpFileCreater, canCallerSideDeleteFile).start()
         }
       }

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index d7307a77d477b..340702c65ee33 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -224,14 +224,14 @@ final class ShuffleBlockFetcherIterator(
     // the data and write it to file directly.
     if (req.size > maxReqSizeShuffleToMem) {
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-        blockFetchingListener, true, new Supplier[File] {
+        blockFetchingListener, new Supplier[File] {
           override def get(): File = blockManager.diskBlockManager.createTempLocalBlock()._2
         }, new Supplier[java.lang.Boolean] {
           override def get(): java.lang.Boolean = isZombie
         })
     } else {
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-        blockFetchingListener, false, null, null)
+        blockFetchingListener, null, null)
     }
   }

diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index 93d7826fa6dd1..e443fa2c98a2f 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
       override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
         promise.success(data.retain())
       }
-    }, false, null, null)
+    }, null, null)

     ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS))
     promise.future.value.get
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index a14c2ad04a1de..920aac90aadcf 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -1384,9 +1384,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      toDisk: Boolean,
       tmpFileCreater: Supplier[File],
-      shuffleBlockFetcherIteratorIsZombie: Supplier[java.lang.Boolean]): Unit = {
+      canCallerSideDeleteFile: Supplier[java.lang.Boolean]): Unit = {
     listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
   }
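Most of the test churn in these patches is mechanical: Mockito stubs and verifications must name one matcher per parameter, so every change to the arity of fetchBlocks forces every any()-stub to change with it. A tiny sketch of that coupling, using a hypothetical three-argument interface and Mockito 2 matchers:

    import static org.mockito.ArgumentMatchers.any;
    import static org.mockito.ArgumentMatchers.anyInt;
    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.when;

    interface Fetcher {
      String fetch(String host, int port, Object listener);
    }

    public class ArityDemo {
      public static void main(String[] args) {
        Fetcher f = mock(Fetcher.class);
        // One matcher per parameter. Adding a parameter to fetch() would not
        // compile until a fourth matcher is added, which is exactly why every
        // fetchBlocks stub in this series gains or loses an any().
        when(f.fetch(any(), anyInt(), any())).thenReturn("ok");
        System.out.println(f.fetch("host", 7337, new Object()));
      }
    }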
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index c959ce00e24fd..c4d359f66f3a1 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
 import java.io.{File, InputStream, IOException}
 import java.util.UUID
 import java.util.concurrent.Semaphore
+import java.util.function.Supplier

 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
@@ -46,7 +47,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   /** Creates a mock [[BlockTransferService]] that returns data from the given map. */
   private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
@@ -138,7 +139,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // 3 local blocks, and 2 remote blocks
     // (but from the same block manager so one call to fetchBlocks)
     verify(blockManager, times(3)).getBlockData(any())
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any())
+    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any())
   }

   test("release current unexhausted buffer in case the task completes early") {
@@ -157,7 +158,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -224,7 +225,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -289,7 +290,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -328,7 +329,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val (id1, _) = iterator.next()
     assert(id1 === ShuffleBlockId(0, 0, 0))

-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -370,7 +371,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(corruptBuffer.createInputStream()).thenReturn(corruptStream)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
        override def answer(invocation: InvocationOnMock): Unit = {
          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -432,12 +433,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val remoteBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
     val transfer = mock(classOf[BlockTransferService])
-    var toDisk: Boolean = false
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any(), any()))
+    var tmpFileCreater: Supplier[File] = null
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-          toDisk = invocation.getArguments()(5).asInstanceOf[Boolean]
+          tmpFileCreater = invocation.getArguments()(5).asInstanceOf[Supplier[File]]
           Future {
             listener.onBlockFetchSuccess(
               ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
@@ -466,13 +467,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     fetchShuffleBlock(blocksByAddress1)
     // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
     // shuffle block to disk.
-    assert(!toDisk)
+    assert(tmpFileCreater == null)

     val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
     fetchShuffleBlock(blocksByAddress2)
     // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
     // shuffle block to disk.
-    assert(toDisk)
+    assert(tmpFileCreater != null)
   }
 }

From bb82dd390ad63e5ebf93e543d903e76e7308b5ef Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Sun, 9 Jul 2017 21:39:58 -0700
Subject: [PATCH 6/7] Fix a potential file leak due to a race condition and refactor (#1)

---
 .../shuffle/ExternalShuffleClient.java        |  8 ++---
 .../shuffle/OneForOneBlockFetcher.java        | 22 +++++-------
 .../spark/network/shuffle/ShuffleClient.java  | 15 +++++---
 .../shuffle/TempShuffleFileManager.java       | 36 +++++++++++++++++++
 .../spark/network/BlockTransferService.scala  | 10 +++---
 .../netty/NettyBlockTransferService.scala     |  9 ++---
 .../storage/ShuffleBlockFetcherIterator.scala | 33 ++++++++++-------
 .../NettyBlockTransferSecuritySuite.scala     |  2 +-
 .../spark/storage/BlockManagerSuite.scala     |  7 ++--
 .../ShuffleBlockFetcherIteratorSuite.scala    | 27 +++++++-------
 10 files changed, 100 insertions(+), 69 deletions(-)
 create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java
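The race this patch closes: with the Supplier pair, the fetcher sampled the zombie flag and the iterator's cleanup ran as two unsynchronized steps, so a download completing in the window between them could leave a file nobody owned. A stripped-down sketch of that check-then-act gap (RaceSketch is illustrative only, not the Spark code path):

    import java.io.File;
    import java.io.IOException;
    import java.util.function.Supplier;

    class RaceSketch {
      static volatile boolean isZombie = false;

      public static void main(String[] args) throws IOException {
        Supplier<Boolean> canCallerSideDeleteFile = () -> isZombie;
        File f = File.createTempFile("leak-demo", ".tmp");
        if (canCallerSideDeleteFile.get()) {
          f.delete();  // consumer is gone, so the fetcher cleans up
        }
        // If cleanup flips isZombie right here, the consumer never sees f and
        // the fetcher has already decided to keep it: the file leaks.
        isZombie = true;
        f.deleteOnExit();  // demo-only safety net
      }
    }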
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index dbf183ee6ed05..31bd24e5038b2 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,11 +17,9 @@

 package org.apache.spark.network.shuffle;

-import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
-import java.util.function.Supplier;

 import com.google.common.collect.Lists;
 import org.slf4j.Logger;
@@ -92,8 +90,7 @@ public void fetchBlocks(
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      Supplier<File> tmpFileCreater,
-      Supplier<Boolean> canCallerSideDeleteFile) {
+      TempShuffleFileManager tempShuffleFileManager) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
@@ -101,8 +98,7 @@ public void fetchBlocks(
         (blockIds1, listener1) -> {
           TransportClient client = clientFactory.createClient(host, port);
           new OneForOneBlockFetcher(client, appId, execId,
-            blockIds1, listener1, conf, tmpFileCreater,
-            canCallerSideDeleteFile).start();
+            blockIds1, listener1, conf, tempShuffleFileManager).start();
         };

       int maxRetries = conf.maxIORetries();
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 641309c20c05a..2f160d12af22b 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -24,7 +24,6 @@
 import java.nio.channels.Channels;
 import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
-import java.util.function.Supplier;

 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -58,9 +57,8 @@ public class OneForOneBlockFetcher {
   private final String[] blockIds;
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
-  private TransportConf transportConf = null;
-  private Supplier<File> tmpFileCreater;
-  private Supplier<Boolean> canCallerSideDeleteFile;
+  private final TransportConf transportConf;
+  private final TempShuffleFileManager tempShuffleFileManager;

   private StreamHandle streamHandle = null;

@@ -71,7 +69,7 @@ public OneForOneBlockFetcher(
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf) {
-    this(client, appId, execId, blockIds, listener, transportConf, null, null);
+    this(client, appId, execId, blockIds, listener, transportConf, null);
   }

@@ -81,18 +79,14 @@ public OneForOneBlockFetcher(
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf,
-      Supplier<File> tmpFileCreater,
-      Supplier<Boolean> canCallerSideDeleteFile) {
+      TempShuffleFileManager tempShuffleFileManager) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
     this.transportConf = transportConf;
-    this.tmpFileCreater = tmpFileCreater;
-    this.canCallerSideDeleteFile = canCallerSideDeleteFile;
-    assert (tmpFileCreater == null && canCallerSideDeleteFile == null ||
-      tmpFileCreater != null && canCallerSideDeleteFile != null);
+    this.tempShuffleFileManager = tempShuffleFileManager;
   }

   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -131,7 +125,7 @@ public void onSuccess(ByteBuffer response) {
         // Immediately request all chunks -- we expect that the total size of the request is
         // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
         for (int i = 0; i < streamHandle.numChunks; i++) {
-          if (tmpFileCreater != null) {
+          if (tempShuffleFileManager != null) {
             client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
               new DownloadCallback(i));
           } else {
@@ -170,7 +164,7 @@ private class DownloadCallback implements StreamCallback {
     private int chunkIndex;

     DownloadCallback(int chunkIndex) throws IOException {
-      this.targetFile = tmpFileCreater.get();
+      this.targetFile = tempShuffleFileManager.createTempShuffleFile();
       this.channel = Channels.newChannel(new FileOutputStream(targetFile));
       this.chunkIndex = chunkIndex;
     }
@@ -186,7 +180,7 @@ public void onComplete(String streamId) throws IOException {
       ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
         targetFile.length());
       listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
-      if (canCallerSideDeleteFile.get()) {
+      if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) {
         targetFile.delete();
       }
     }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index 9f33294311d49..9e77bee7f9ee6 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -18,8 +18,6 @@
 package org.apache.spark.network.shuffle;

 import java.io.Closeable;
-import java.io.File;
-import java.util.function.Supplier;

 /** Provides an interface for reading shuffle files, either from an Executor or external service. */
 public abstract class ShuffleClient implements Closeable {
@@ -36,6 +34,16 @@ public void init(String appId) { }
    * Note that this API takes a sequence so the implementation can batch requests, and does not
    * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
    * the data of a block is fetched, rather than waiting for all blocks to be fetched.
+   *
+   * @param host the host of the remote node.
+   * @param port the port of the remote node.
+   * @param execId the executor id.
+   * @param blockIds block ids to fetch.
+   * @param listener the listener to receive block fetching status.
+   * @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files.
+   *                               If it's not null, the remote blocks will be streamed into temp
+   *                               shuffle files to reduce the memory usage; otherwise, they will
+   *                               be kept in memory.
    */
   public abstract void fetchBlocks(
       String host,
@@ -43,6 +51,5 @@ public abstract void fetchBlocks(
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      Supplier<File> tmpFileCreater,
-      Supplier<Boolean> canCallerSideDeleteFile);
+      TempShuffleFileManager tempShuffleFileManager);
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java
new file mode 100644
index 0000000000000..84a5ed6a276bd
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.File;
+
+/**
+ * A manager to create temp shuffle block files to reduce the memory usage, and also to clean up
+ * temp files when they are no longer needed.
+ */
+public interface TempShuffleFileManager {
+
+  /** Create a temp shuffle block file. */
+  File createTempShuffleFile();
+
+  /**
+   * Register a temp shuffle file to clean up when it won't be used any more. Returns whether the
+   * file was registered successfully. If `false`, the caller should clean up the file by itself.
+   */
+  boolean registerTempShuffleFileToClean(File file);
+}
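TempShuffleFileManager turns the hand-off into a single atomic step: the downloader offers the finished file via registerTempShuffleFileToClean, and a false return means the consumer is already closed and the downloader must delete the file itself, so exactly one side owns the file at any time. A minimal in-memory sketch of that contract (InMemoryManager is a stand-in for illustration, not Spark's ShuffleBlockFetcherIterator; it mirrors the interface's method signatures rather than implementing the Spark type):

    import java.io.File;
    import java.io.IOException;
    import java.util.HashSet;
    import java.util.Set;

    class InMemoryManager {
      private final Set<File> files = new HashSet<>();
      private boolean closed = false;

      // Mirrors TempShuffleFileManager.createTempShuffleFile().
      synchronized File createTempShuffleFile() throws IOException {
        return File.createTempFile("demo-shuffle", ".tmp");
      }

      // Mirrors registerTempShuffleFileToClean(): refuses the file once closed,
      // making "register or reject" one atomic decision under the lock.
      synchronized boolean registerTempShuffleFileToClean(File file) {
        if (closed) {
          return false;  // caller now owns deletion
        }
        files.add(file);
        return true;     // manager now owns deletion
      }

      synchronized void close() {
        closed = true;
        for (File f : files) {
          f.delete();
        }
        files.clear();
      }

      public static void main(String[] args) throws IOException {
        InMemoryManager m = new InMemoryManager();
        File f = m.createTempShuffleFile();
        if (!m.registerTempShuffleFileToClean(f)) {
          f.delete();  // rejected: clean up locally
        }
        m.close();     // accepted files are deleted here
      }
    }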
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index f4cf5fed6c367..fe5fd2da039bb 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,9 +17,8 @@

 package org.apache.spark.network

-import java.io.{Closeable, File}
+import java.io.Closeable
 import java.nio.ByteBuffer
-import java.util.function.Supplier

 import scala.concurrent.{Future, Promise}
 import scala.concurrent.duration.Duration
@@ -27,7 +26,7 @@ import scala.reflect.ClassTag

 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.ThreadUtils

@@ -69,8 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tmpFileCreater: Supplier[File],
-      canCallerSideDeleteFile: Supplier[java.lang.Boolean]): Unit
+      tempShuffleFileManager: TempShuffleFileManager): Unit

   /**
    * Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -103,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
           ret.flip()
           result.success(new NioManagedBuffer(ret))
         }
-      }, tmpFileCreater = null, canCallerSideDeleteFile = null)
+      }, tempShuffleFileManager = null)
     ThreadUtils.awaitResult(result.future, Duration.Inf)
   }

diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 16fa3f7db799a..30ff93897f98a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,9 +17,7 @@

 package org.apache.spark.network.netty

-import java.io.File
 import java.nio.ByteBuffer
-import java.util.function.Supplier

 import scala.collection.JavaConverters._
 import scala.concurrent.{Future, Promise}
@@ -31,7 +29,7 @@ import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
 import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
 import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager}
 import org.apache.spark.network.shuffle.protocol.UploadBlock
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
@@ -91,15 +89,14 @@ private[spark] class NettyBlockTransferService(
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tmpFileCreater: Supplier[File],
-      canCallerSideDeleteFile: Supplier[java.lang.Boolean]): Unit = {
+      tempShuffleFileManager: TempShuffleFileManager): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
         override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
           val client = clientFactory.createClient(host, port)
           new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
-            transportConf, tmpFileCreater, canCallerSideDeleteFile).start()
+            transportConf, tempShuffleFileManager).start()
         }
       }

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 340702c65ee33..81d822dc8a98f 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -20,7 +20,6 @@ package org.apache.spark.storage
 import java.io.{File, InputStream, IOException}
 import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
-import java.util.function.Supplier
 import javax.annotation.concurrent.GuardedBy

 import scala.collection.mutable
@@ -29,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}

 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
 import org.apache.spark.util.io.ChunkedByteBufferOutputStream
@@ -67,7 +66,7 @@ final class ShuffleBlockFetcherIterator(
     maxReqsInFlight: Int,
     maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean)
-  extends Iterator[(BlockId, InputStream)] with Logging {
+  extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging {

   import ShuffleBlockFetcherIterator._

@@ -136,7 +135,8 @@ final class ShuffleBlockFetcherIterator(
    * A set to store the files used for shuffling remote huge blocks. Files in this set will be
    * deleted during cleanup. This is a layer of defensiveness against disk file leaks.
    */
-  val shuffleFilesSet = mutable.HashSet[File]()
+  @GuardedBy("this")
+  private[this] val shuffleFilesSet = mutable.HashSet[File]()

   initialize()

@@ -150,6 +150,19 @@ final class ShuffleBlockFetcherIterator(
     currentResult = null
   }

+  override def createTempShuffleFile(): File = {
+    blockManager.diskBlockManager.createTempLocalBlock()._2
+  }
+
+  override def registerTempShuffleFileToClean(file: File): Boolean = synchronized {
+    if (isZombie) {
+      false
+    } else {
+      shuffleFilesSet += file
+      true
+    }
+  }
+
   /**
    * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
    */
@@ -167,7 +180,6 @@ final class ShuffleBlockFetcherIterator(
       if (address != blockManager.blockManagerId) {
         shuffleMetrics.incRemoteBytesRead(buf.size)
         if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
-          shuffleFilesSet += buf.asInstanceOf[FileSegmentManagedBuffer].getFile
           shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
         }
         shuffleMetrics.incRemoteBlocksFetched(1)
@@ -178,7 +190,7 @@ final class ShuffleBlockFetcherIterator(
     }
     shuffleFilesSet.foreach { file =>
       if (!file.delete()) {
-        logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath())
+        logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath())
       }
     }
   }
@@ -224,14 +236,10 @@ final class ShuffleBlockFetcherIterator(
     // the data and write it to file directly.
     if (req.size > maxReqSizeShuffleToMem) {
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-        blockFetchingListener, new Supplier[File] {
-          override def get(): File = blockManager.diskBlockManager.createTempLocalBlock()._2
-        }, new Supplier[java.lang.Boolean] {
-          override def get(): java.lang.Boolean = isZombie
-        })
+        blockFetchingListener, this)
     } else {
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-        blockFetchingListener, null, null)
+        blockFetchingListener, null)
     }
   }

@@ -369,7 +377,6 @@ final class ShuffleBlockFetcherIterator(
       if (address != blockManager.blockManagerId) {
         shuffleMetrics.incRemoteBytesRead(buf.size)
         if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
-          shuffleFilesSet += buf.asInstanceOf[FileSegmentManagedBuffer].getFile
           shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
         }
         shuffleMetrics.incRemoteBlocksFetched(1)
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index e443fa2c98a2f..474e30144f629 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
       override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
         promise.success(data.retain())
       }
-    }, null, null)
+    }, null)

     ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS))
     promise.future.value.get
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 920aac90aadcf..755a61a438a6a 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.storage

 import java.io.File
 import java.nio.ByteBuffer
-import java.util.function.Supplier

 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -46,8 +45,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf}
 import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap}
-import org.apache.spark.network.shuffle.BlockFetchingListener
-import org.apache.spark.network.shuffle.ShuffleClient
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
 import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
@@ -1384,8 +1382,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tmpFileCreater: Supplier[File],
-      canCallerSideDeleteFile: Supplier[java.lang.Boolean]): Unit = {
+      tempShuffleFileManager: TempShuffleFileManager): Unit = {
     listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
   }
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index c4d359f66f3a1..6a70cedf769b8 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.storage
 import java.io.{File, InputStream, IOException}
 import java.util.UUID
 import java.util.concurrent.Semaphore
-import java.util.function.Supplier

 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
@@ -34,7 +33,7 @@ import org.scalatest.PrivateMethodTester

 import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
@@ -47,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   /** Creates a mock [[BlockTransferService]] that returns data from the given map. */
   private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
@@ -139,7 +138,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // 3 local blocks, and 2 remote blocks
     // (but from the same block manager so one call to fetchBlocks)
     verify(blockManager, times(3)).getBlockData(any())
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any())
+    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
   }

   test("release current unexhausted buffer in case the task completes early") {
@@ -158,7 +157,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -225,7 +224,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -290,7 +289,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -329,7 +328,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val (id1, _) = iterator.next()
     assert(id1 === ShuffleBlockId(0, 0, 0))

-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -371,7 +370,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(corruptBuffer.createInputStream()).thenReturn(corruptStream)

     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -433,12 +432,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val remoteBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
     val transfer = mock(classOf[BlockTransferService])
-    var tmpFileCreater: Supplier[File] = null
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+    var tempShuffleFileManager: TempShuffleFileManager = null
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-          tmpFileCreater = invocation.getArguments()(5).asInstanceOf[Supplier[File]]
+          tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager]
           Future {
             listener.onBlockFetchSuccess(
               ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
@@ -467,13 +466,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     fetchShuffleBlock(blocksByAddress1)
     // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
     // shuffle block to disk.
-    assert(tmpFileCreater == null)
+    assert(tempShuffleFileManager == null)

     val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
     fetchShuffleBlock(blocksByAddress2)
     // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
     // shuffle block to disk.
-    assert(tmpFileCreater != null)
+    assert(tempShuffleFileManager != null)
   }
 }

From 15001a6dac4b02ef408a7e800a146e776a204435 Mon Sep 17 00:00:00 2001
From: jinxing
Date: Mon, 10 Jul 2017 12:59:39 +0800
Subject: [PATCH 7/7] Fix unit test.
---
 .../spark/network/shuffle/ExternalShuffleIntegrationSuite.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 2aaa04def24d4..a6a1b8d0ac3f1 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) {
           }
         }
       }
-    }, null, null);
+    }, null);

     if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
       fail("Timeout getting response from the server");