[CORE] Updates to remote cache reads
Covered by tests in DistributedSuite
squito committed Sep 13, 2018
1 parent 6d742d1 commit 575fea1
Showing 15 changed files with 299 additions and 62 deletions.
ManagedBuffer.java
@@ -36,7 +36,10 @@
*/
public abstract class ManagedBuffer {

/** Number of bytes of the data. */
/**
* Number of bytes of the data. If this buffer decrypts the data for all of the views into it,
* this is the size of the decrypted data.
*/
public abstract long size();

/**
DownloadFile.java (new file)
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle;

import java.io.IOException;

/**
* A handle on the file used when fetching remote data to disk. Used to ensure the lifecycle of
* writing the data, reading it back, and then cleaning it up is followed. Specific implementations
* may also handle encryption. The data can be read only via DownloadFileWritableChannel,
* which ensures data is not read until after the writer is closed.
*/
public interface DownloadFile {
/**
* Delete the file.
*
* @return <code>true</code> if and only if the file or directory is
* successfully deleted; <code>false</code> otherwise
*/
public boolean delete();

/**
* A channel for writing data to the file. This special channel allows access to the data for
* reading, after the channel is closed, via {@link DownloadFileWritableChannel#closeAndRead()}.
*/
public DownloadFileWritableChannel openForWriting() throws IOException;

/**
* The path of the file, intended only for debug purposes.
*/
public String path();
}
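
To make the lifecycle described in the javadoc above concrete, here is a minimal hedged sketch of a consumer, mirroring what DownloadCallback in OneForOneBlockFetcher (further down in this diff) actually does. The method name and all parameters are illustrative stand-ins, not code from this commit.

// Hedged sketch of the write -> read -> clean-up lifecycle; the names here
// (fetchToDisk, file, manager, listener, blockId, chunk) are illustrative only.
static void fetchToDisk(DownloadFile file, DownloadFileManager manager,
    BlockFetchingListener listener, String blockId, ByteBuffer chunk) throws IOException {
  DownloadFileWritableChannel channel = file.openForWriting();
  channel.write(chunk);                           // write phase: stream the remote bytes in
  ManagedBuffer buffer = channel.closeAndRead();  // read phase: only legal after the writer closes
  listener.onBlockFetchSuccess(blockId, buffer);
  if (!manager.registerTempFileToClean(file)) {   // clean-up phase
    file.delete();                                // the manager declined, so delete it ourselves
  }
}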
TempFileManager.java → DownloadFileManager.java
@@ -17,20 +17,20 @@

package org.apache.spark.network.shuffle;

import java.io.File;
import org.apache.spark.network.util.TransportConf;

/**
* A manager to create temp block files to reduce the memory usage and also clean temp
* files when they won't be used any more.
* A manager to create temp block files used when fetching remote data, to reduce memory usage.
* It will clean up the files once they are no longer needed.
*/
public interface TempFileManager {
public interface DownloadFileManager {

/** Create a temp block file. */
File createTempFile();
DownloadFile createTempFile(TransportConf transportConf);

/**
* Register a temp file to clean up when it is no longer needed. Returns whether the
* file was registered successfully. If `false`, the caller should clean up the file itself.
*/
boolean registerTempFileToClean(File file);
boolean registerTempFileToClean(DownloadFile file);
}
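
For illustration, a minimal DownloadFileManager might look like the hedged sketch below. ScratchDownloadFileManager is a hypothetical name and the whole implementation is an assumption, not part of this commit; it leans on SimpleDownloadFile, the concrete DownloadFile added further down in this diff.

// Hypothetical minimal DownloadFileManager; not part of this commit.
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.spark.network.util.TransportConf;

public class ScratchDownloadFileManager implements DownloadFileManager {
  private final File dir;
  private final Set<DownloadFile> filesToClean = ConcurrentHashMap.newKeySet();

  public ScratchDownloadFileManager(File dir) {
    this.dir = dir;
  }

  @Override
  public DownloadFile createTempFile(TransportConf transportConf) {
    try {
      // SimpleDownloadFile is the plain-file implementation added by this commit.
      return new SimpleDownloadFile(File.createTempFile("fetch", ".tmp", dir), transportConf);
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  @Override
  public boolean registerTempFileToClean(DownloadFile file) {
    // Returning false would tell the caller to delete the file itself.
    return filesToClean.add(file);
  }

  // A real manager would also delete everything in filesToClean on shutdown.
}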
DownloadFileWritableChannel.java (new file)
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle;

import java.nio.channels.WritableByteChannel;

import org.apache.spark.network.buffer.ManagedBuffer;

/**
* A channel for writing data that is fetched to disk. It allows access to the written data
* only after the writer has been closed. Used with DownloadFile and DownloadFileManager.
*/
public interface DownloadFileWritableChannel extends WritableByteChannel {
public ManagedBuffer closeAndRead();
}
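
The contract is small enough to capture in an in-memory test double. The sketch below is an assumption for illustration only: the hypothetical InMemoryDownloadChannel and its use of NioManagedBuffer are not part of this commit, and it skips the ClosedChannelException checks a fully conformant WritableByteChannel would make.

// Hypothetical in-memory channel, e.g. for tests; not part of this commit.
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;

public class InMemoryDownloadChannel implements DownloadFileWritableChannel {
  private final ByteArrayOutputStream out = new ByteArrayOutputStream();
  private boolean open = true;

  @Override
  public int write(ByteBuffer src) {
    int n = src.remaining();
    byte[] bytes = new byte[n];
    src.get(bytes);
    out.write(bytes, 0, n);
    return n;
  }

  @Override
  public ManagedBuffer closeAndRead() {
    open = false;  // close the writer first, then expose what was written
    return new NioManagedBuffer(ByteBuffer.wrap(out.toByteArray()));
  }

  @Override
  public boolean isOpen() {
    return open;
  }

  @Override
  public void close() {
    open = false;
  }
}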
ExternalShuffleClient.java
@@ -91,15 +91,15 @@ public void fetchBlocks(
String execId,
String[] blockIds,
BlockFetchingListener listener,
TempFileManager tempFileManager) {
DownloadFileManager downloadFileManager) {
checkInit();
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
(blockIds1, listener1) -> {
TransportClient client = clientFactory.createClient(host, port);
new OneForOneBlockFetcher(client, appId, execId,
blockIds1, listener1, conf, tempFileManager).start();
blockIds1, listener1, conf, downloadFileManager).start();
};

int maxRetries = conf.maxIORetries();
OneForOneBlockFetcher.java
@@ -17,18 +17,13 @@

package org.apache.spark.network.shuffle;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.Arrays;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
@@ -58,7 +53,7 @@ public class OneForOneBlockFetcher {
private final BlockFetchingListener listener;
private final ChunkReceivedCallback chunkCallback;
private final TransportConf transportConf;
private final TempFileManager tempFileManager;
private final DownloadFileManager downloadFileManager;

private StreamHandle streamHandle = null;

@@ -79,14 +74,14 @@ public OneForOneBlockFetcher(
String[] blockIds,
BlockFetchingListener listener,
TransportConf transportConf,
TempFileManager tempFileManager) {
DownloadFileManager downloadFileManager) {
this.client = client;
this.openMessage = new OpenBlocks(appId, execId, blockIds);
this.blockIds = blockIds;
this.listener = listener;
this.chunkCallback = new ChunkCallback();
this.transportConf = transportConf;
this.tempFileManager = tempFileManager;
this.downloadFileManager = downloadFileManager;
}

/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -125,7 +120,7 @@ public void onSuccess(ByteBuffer response) {
// Immediately request all chunks -- we expect that the total size of the request is
// reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
for (int i = 0; i < streamHandle.numChunks; i++) {
if (tempFileManager != null) {
if (downloadFileManager != null) {
client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
new DownloadCallback(i));
} else {
@@ -159,13 +154,13 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {

private class DownloadCallback implements StreamCallback {

private WritableByteChannel channel = null;
private File targetFile = null;
private DownloadFileWritableChannel channel = null;
private DownloadFile targetFile = null;
private int chunkIndex;

DownloadCallback(int chunkIndex) throws IOException {
this.targetFile = tempFileManager.createTempFile();
this.channel = Channels.newChannel(new FileOutputStream(targetFile));
this.targetFile = downloadFileManager.createTempFile(transportConf);
this.channel = targetFile.openForWriting();
this.chunkIndex = chunkIndex;
}

Expand All @@ -178,11 +173,8 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {

@Override
public void onComplete(String streamId) throws IOException {
channel.close();
ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
targetFile.length());
listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
if (!tempFileManager.registerTempFileToClean(targetFile)) {
listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead());
if (!downloadFileManager.registerTempFileToClean(targetFile)) {
targetFile.delete();
}
}
ShuffleClient.java
@@ -43,7 +43,7 @@ public void init(String appId) { }
* @param execId the executor id.
* @param blockIds block ids to fetch.
* @param listener the listener to receive block fetching status.
* @param tempFileManager TempFileManager to create and clean temp files.
* @param downloadFileManager DownloadFileManager to create and clean temp files.
*                            If it's not <code>null</code>, the remote blocks will be streamed
*                            into temp shuffle files to reduce memory usage; otherwise,
*                            they will be kept in memory.
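
A hedged sketch of the two call modes this parameter selects. The listener body, host and port values, block id, and the `downloadFileManager` instance are illustrative assumptions; only the fetchBlocks and BlockFetchingListener signatures come from this diff.

// Illustrative only; `client` is any ShuffleClient, `downloadFileManager` any DownloadFileManager.
BlockFetchingListener listener = new BlockFetchingListener() {
  @Override
  public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
    // consume the block; with a non-null manager, `data` is backed by a DownloadFile
  }
  @Override
  public void onBlockFetchFailure(String blockId, Throwable exception) {
    // retry or surface the failure
  }
};
// In-memory mode: fetched blocks are held as buffers in memory.
client.fetchBlocks("host", 7337, "exec-1", new String[]{"shuffle_0_0_0"}, listener, null);
// Disk mode: fetched blocks stream into temp files created by the manager.
client.fetchBlocks("host", 7337, "exec-1", new String[]{"shuffle_0_0_0"}, listener,
  downloadFileManager);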
@@ -54,7 +54,7 @@ public abstract void fetchBlocks(
String execId,
String[] blockIds,
BlockFetchingListener listener,
TempFileManager tempFileManager);
DownloadFileManager downloadFileManager);

/**
* Get the shuffle MetricsSet from ShuffleClient; this will be used in MetricsSystem to
SimpleDownloadFile.java (new file)
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;

import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.util.TransportConf;

/**
* A DownloadFile that does not take any encryption settings into account for reading and
* writing data.
*
* This does *not* mean the data in the file is unencrypted -- it could be that the data was
* already encrypted when it was written, and a subsequent layer is responsible for decrypting it.
*/
public class SimpleDownloadFile implements DownloadFile {

private final File file;
private final TransportConf transportConf;

public SimpleDownloadFile(File file, TransportConf transportConf) {
this.file = file;
this.transportConf = transportConf;
}

@Override
public boolean delete() {
return file.delete();
}

@Override
public DownloadFileWritableChannel openForWriting() throws IOException {
return new SimpleDownloadWritableChannel();
}

@Override
public String path() {
return file.getAbsolutePath();
}

private class SimpleDownloadWritableChannel implements DownloadFileWritableChannel {

private final WritableByteChannel channel;

SimpleDownloadWritableChannel() throws FileNotFoundException {
channel = Channels.newChannel(new FileOutputStream(file));
}

@Override
public ManagedBuffer closeAndRead() {
return new FileSegmentManagedBuffer(transportConf, file, 0, file.length());
}

@Override
public int write(ByteBuffer src) throws IOException {
return channel.write(src);
}

@Override
public boolean isOpen() {
return channel.isOpen();
}

@Override
public void close() throws IOException {
channel.close();
}
}
}
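
A short, hedged usage sketch in a test style; the setup names are assumptions and `transportConf` is assumed to be available in scope. Note that for this implementation, closeAndRead() simply wraps the written file in a FileSegmentManagedBuffer.

// Illustrative usage of SimpleDownloadFile; not part of this commit.
File tmp = File.createTempFile("block", ".data");
DownloadFile download = new SimpleDownloadFile(tmp, transportConf);
DownloadFileWritableChannel channel = download.openForWriting();
channel.write(ByteBuffer.wrap("payload".getBytes(StandardCharsets.UTF_8)));
ManagedBuffer buffer = channel.closeAndRead();  // a FileSegmentManagedBuffer over tmp
// ... consume `buffer` ...
download.delete();                              // clean up once the buffer is no longer needed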
BlockTransferService.scala
@@ -26,7 +26,7 @@ import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.ThreadUtils

@@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Logging
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener,
tempFileManager: TempFileManager): Unit
tempFileManager: DownloadFileManager): Unit

/**
* Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -92,7 +92,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Logging
port: Int,
execId: String,
blockId: String,
tempFileManager: TempFileManager): ManagedBuffer = {
tempFileManager: DownloadFileManager): ManagedBuffer = {
// A monitor for the thread to wait on.
val result = Promise[ManagedBuffer]()
fetchBlocks(host, port, execId, Array(blockId),
NettyBlockTransferService.scala
@@ -32,7 +32,7 @@ import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher}
import org.apache.spark.network.shuffle.protocol.UploadBlock
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
@@ -105,7 +105,7 @@ private[spark] class NettyBlockTransferService(
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener,
tempFileManager: TempFileManager): Unit = {
tempFileManager: DownloadFileManager): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {