[SPARK-21175] Reject OpenBlocks when memory shortage on shuffle service.
## What changes were proposed in this pull request?

A shuffle service can serve blocks from multiple apps/tasks, so it can suffer high memory usage when many shuffle reads happen at the same time. In my cluster, OOM always happens on the shuffle service; analyzing a heap dump shows that the memory used by Netty (ChannelOutboundBuffer entries) can reach 2~3 GB. It makes sense to reject "open blocks" requests when memory usage on the shuffle service is high.

93dd0c5 and 85c6ce6 tried to alleviate the memory pressure on the shuffle service but could not address the root cause. This PR proposes to control the concurrency of shuffle reads: when the number of chunks being transferred reaches a configurable limit, further fetch requests cause the connection to be closed.
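
The limit is exposed as a new configuration key, `spark.shuffle.maxChunksBeingTransferred`, defaulting to `Long.MAX_VALUE` (i.e. effectively unlimited, so behavior is unchanged unless the key is set). A minimal sketch of how an operator might set it from a Java job — the value 4096 is purely illustrative, not a recommendation:

```java
import org.apache.spark.SparkConf;

// Hypothetical example: cap how many chunks the shuffle service serves
// concurrently. The default (Long.MAX_VALUE) effectively disables the limit.
SparkConf conf = new SparkConf()
    .set("spark.shuffle.maxChunksBeingTransferred", "4096");
```

Once the in-flight count reaches the limit, new fetch and stream requests cause the server to close the connection, and the client is expected to retry.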

## How was this patch tested?
Added a unit test.

Author: jinxing <jinxing6042@126.com>

Closes #18388 from jinxing64/SPARK-21175.
jinxing authored and cloud-fan committed Jul 25, 2017
1 parent 996a809 commit 799e131
Showing 7 changed files with 265 additions and 13 deletions.
@@ -168,7 +168,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
rpcHandler);
rpcHandler, conf.maxChunksBeingTransferred());
return new TransportChannelHandler(client, responseHandler, requestHandler,
conf.connectionTimeoutMs(), closeIdleConnections);
}
@@ -25,6 +25,8 @@

import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -53,6 +55,9 @@ private static class StreamState {
// that the caller only requests each chunk one at a time, in order.
int curChunk = 0;

// Used to keep track of the number of chunks being transferred and not finished yet.
volatile long chunksBeingTransferred = 0L;

StreamState(String appId, Iterator<ManagedBuffer> buffers) {
this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
@@ -96,18 +101,25 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {

@Override
public ManagedBuffer openStream(String streamChunkId) {
String[] array = streamChunkId.split("_");
assert array.length == 2:
"Stream id and chunk index should be specified when open stream for fetching block.";
long streamId = Long.valueOf(array[0]);
int chunkIndex = Integer.valueOf(array[1]);
return getChunk(streamId, chunkIndex);
Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId);
return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight());
}

public static String genStreamChunkId(long streamId, int chunkId) {
return String.format("%d_%d", streamId, chunkId);
}

// Parse streamChunkId to be stream id and chunk id. This is used when fetch remote chunk as a
// stream.
public static Pair<Long, Integer> parseStreamChunkId(String streamChunkId) {
String[] array = streamChunkId.split("_");
assert array.length == 2:
"Stream id and chunk index should be specified.";
long streamId = Long.valueOf(array[0]);
int chunkIndex = Integer.valueOf(array[1]);
return ImmutablePair.of(streamId, chunkIndex);
}
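
A quick round-trip of the `streamId_chunkIndex` encoding shared by `genStreamChunkId` and `parseStreamChunkId` (a sketch, not part of the patch):

```java
// The two helpers are inverses of each other.
String id = OneForOneStreamManager.genStreamChunkId(1234L, 5);            // "1234_5"
Pair<Long, Integer> parsed = OneForOneStreamManager.parseStreamChunkId(id);
assert parsed.getLeft() == 1234L;   // stream id
assert parsed.getRight() == 5;      // chunk index
```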

@Override
public void connectionTerminated(Channel channel) {
// Close all streams which have been associated with the channel.
@@ -139,6 +151,42 @@ public void checkAuthorization(TransportClient client, long streamId) {
}
}

@Override
public void chunkBeingSent(long streamId) {
StreamState streamState = streams.get(streamId);
if (streamState != null) {
streamState.chunksBeingTransferred++;
}
}

@Override
public void streamBeingSent(String streamId) {
chunkBeingSent(parseStreamChunkId(streamId).getLeft());
}

@Override
public void chunkSent(long streamId) {
StreamState streamState = streams.get(streamId);
if (streamState != null) {
streamState.chunksBeingTransferred--;
}
}

@Override
public void streamSent(String streamId) {
chunkSent(OneForOneStreamManager.parseStreamChunkId(streamId).getLeft());
}

@Override
public long chunksBeingTransferred() {
long sum = 0L;
for (StreamState streamState: streams.values()) {
sum += streamState.chunksBeingTransferred;
}
return sum;
}

/**
* Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
* callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
@@ -83,4 +83,31 @@ public void connectionTerminated(Channel channel) { }
*/
public void checkAuthorization(TransportClient client, long streamId) { }

/**
* Return the number of chunks being transferred and not finished yet in this StreamManager.
*/
public long chunksBeingTransferred() {
return 0;
}

/**
* Called when sending a chunk starts.
*/
public void chunkBeingSent(long streamId) { }

/**
* Called when sending a stream starts.
*/
public void streamBeingSent(String streamId) { }

/**
* Called when a chunk is successfully sent.
*/
public void chunkSent(long streamId) { }

/**
* Called when a stream is successfully sent.
*/
public void streamSent(String streamId) { }

}
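
The hooks above are no-ops (and the count is zero) by default, so existing `StreamManager` implementations are unaffected. A minimal sketch of a subclass that tracks in-flight transfers with one global counter — class and field names here are hypothetical; `OneForOneStreamManager` above tracks the count per stream instead:

```java
import java.util.concurrent.atomic.AtomicLong;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.server.StreamManager;

// Pair the *BeingSent hooks (invoked before a write starts) with the *Sent
// hooks (invoked once the write completes), so that chunksBeingTransferred()
// always reflects the current in-flight count.
class CountingStreamManager extends StreamManager {
  private final AtomicLong inFlight = new AtomicLong();

  @Override
  public ManagedBuffer getChunk(long streamId, int chunkIndex) {
    throw new UnsupportedOperationException("not part of this sketch");
  }

  @Override
  public void chunkBeingSent(long streamId) { inFlight.incrementAndGet(); }

  @Override
  public void streamBeingSent(String streamId) { inFlight.incrementAndGet(); }

  @Override
  public void chunkSent(long streamId) { inFlight.decrementAndGet(); }

  @Override
  public void streamSent(String streamId) { inFlight.decrementAndGet(); }

  @Override
  public long chunksBeingTransferred() { return inFlight.get(); }
}
```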
@@ -22,6 +22,7 @@

import com.google.common.base.Throwables;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -65,14 +66,19 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
/** Returns each chunk part of a stream. */
private final StreamManager streamManager;

/** The max number of chunks being transferred and not finished yet. */
private final long maxChunksBeingTransferred;

public TransportRequestHandler(
Channel channel,
TransportClient reverseClient,
RpcHandler rpcHandler) {
RpcHandler rpcHandler,
Long maxChunksBeingTransferred) {
this.channel = channel;
this.reverseClient = reverseClient;
this.rpcHandler = rpcHandler;
this.streamManager = rpcHandler.getStreamManager();
this.maxChunksBeingTransferred = maxChunksBeingTransferred;
}

@Override
@@ -117,7 +123,13 @@ private void processFetchRequest(final ChunkFetchRequest req) {
logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel),
req.streamChunkId);
}

long chunksBeingTransferred = streamManager.chunksBeingTransferred();
if (chunksBeingTransferred >= maxChunksBeingTransferred) {
logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
chunksBeingTransferred, maxChunksBeingTransferred);
channel.close();
return;
}
ManagedBuffer buf;
try {
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
@@ -130,10 +142,25 @@ private void processFetchRequest(final ChunkFetchRequest req) {
return;
}

respond(new ChunkFetchSuccess(req.streamChunkId, buf));
streamManager.chunkBeingSent(req.streamChunkId.streamId);
respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> {
streamManager.chunkSent(req.streamChunkId.streamId);
});
}

private void processStreamRequest(final StreamRequest req) {
if (logger.isTraceEnabled()) {
logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel),
req.streamId);
}

long chunksBeingTransferred = streamManager.chunksBeingTransferred();
if (chunksBeingTransferred >= maxChunksBeingTransferred) {
logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
chunksBeingTransferred, maxChunksBeingTransferred);
channel.close();
return;
}
ManagedBuffer buf;
try {
buf = streamManager.openStream(req.streamId);
@@ -145,7 +172,10 @@
}

if (buf != null) {
respond(new StreamResponse(req.streamId, buf.size(), buf));
streamManager.streamBeingSent(req.streamId);
respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> {
streamManager.streamSent(req.streamId);
});
} else {
respond(new StreamFailure(req.streamId, String.format(
"Stream '%s' was not found.", req.streamId)));
@@ -187,9 +217,9 @@ private void processOneWayMessage(OneWayMessage req) {
* Responds to a single message with some Encodable object. If a failure occurs while sending,
* it will be logged and the channel closed.
*/
private void respond(Encodable result) {
private ChannelFuture respond(Encodable result) {
SocketAddress remoteAddress = channel.remoteAddress();
channel.writeAndFlush(result).addListener(future -> {
return channel.writeAndFlush(result).addListener(future -> {
if (future.isSuccess()) {
logger.trace("Sent result {} to client {}", result, remoteAddress);
} else {
@@ -257,4 +257,10 @@ public Properties cryptoConf() {
return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll());
}

/**
* The max number of chunks allowed to be transferred at the same time on shuffle service.
*/
public long maxChunksBeingTransferred() {
return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE);
}
}
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network;

import java.util.ArrayList;
import java.util.List;

import io.netty.channel.Channel;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.junit.Test;

import static org.mockito.Mockito.*;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.*;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportRequestHandler;

public class TransportRequestHandlerSuite {

@Test
public void handleFetchRequestAndStreamRequest() throws Exception {
RpcHandler rpcHandler = new NoOpRpcHandler();
OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager());
Channel channel = mock(Channel.class);
List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs =
new ArrayList<>();
when(channel.writeAndFlush(any()))
.thenAnswer(invocationOnMock0 -> {
Object response = invocationOnMock0.getArguments()[0];
ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel);
responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture));
return channelFuture;
});

// Prepare the stream.
List<ManagedBuffer> managedBuffers = new ArrayList<>();
managedBuffers.add(new TestManagedBuffer(10));
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
streamManager.registerChannel(channel, streamId);
TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L);

RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0));
requestHandler.handle(request0);
assert responseAndPromisePairs.size() == 1;
assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess;
assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() ==
managedBuffers.get(0);

RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1));
requestHandler.handle(request1);
assert responseAndPromisePairs.size() == 2;
assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess;
assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() ==
managedBuffers.get(1);

// Finish flushing the response for request0.
responseAndPromisePairs.get(0).getRight().finish(true);

RequestMessage request2 = new StreamRequest(String.format("%d_%d", streamId, 2));
requestHandler.handle(request2);
assert responseAndPromisePairs.size() == 3;
assert responseAndPromisePairs.get(2).getLeft() instanceof StreamResponse;
assert ((StreamResponse) (responseAndPromisePairs.get(2).getLeft())).body() ==
managedBuffers.get(2);

// Request3 will trigger the channel to be closed, because the max number of chunks
// being transferred is 2.
RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3));
requestHandler.handle(request3);
verify(channel, times(1)).close();
assert responseAndPromisePairs.size() == 3;
}

private class ExtendedChannelPromise extends DefaultChannelPromise {

private List<GenericFutureListener> listeners = new ArrayList<>();
private boolean success;

public ExtendedChannelPromise(Channel channel) {
super(channel);
success = false;
}

@Override
public ChannelPromise addListener(
GenericFutureListener<? extends Future<? super Void>> listener) {
listeners.add(listener);
return super.addListener(listener);
}

@Override
public boolean isSuccess() {
return success;
}

public void finish(boolean success) {
this.success = success;
listeners.forEach(listener -> {
try {
listener.operationComplete(this);
} catch (Exception e) { }
});
}
}
}
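
(To trace the limit of 2 in the test above: request0 and request1 leave two chunks in flight; finishing request0's write frees one slot, so the stream request2 is served; request3 then finds two transfers still unfinished, and the handler closes the channel without responding.)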
docs/configuration.md (7 additions, 0 deletions)
@@ -631,6 +631,13 @@ Apart from these, the following properties are also available, and may be useful
Max number of entries to keep in the index cache of the shuffle service.
</td>
</tr>
<tr>
<td><code>spark.shuffle.maxChunksBeingTransferred</code></td>
<td>Long.MAX_VALUE</td>
<td>
The max number of chunks allowed to be transferred at the same time on shuffle service.
</td>
</tr>
<tr>
<td><code>spark.shuffle.sort.bypassMergeThreshold</code></td>
<td>200</td>
