Skip to content
Permalink
Browse files

[SPARK-26604][CORE][BACKPORT-2.4] Clean up channel registration for S…

…treamManager

## What changes were proposed in this pull request?

This is mostly a clean backport of #23521 to branch-2.4

## How was this patch tested?

I've tested this with a hack in `TransportRequestHandler` to force `ChunkFetchRequest` to get dropped.

Then making a number of `ExternalShuffleClient.fetchChunk` requests (which `OpenBlocks` then `ChunkFetchRequest`) and closing out of my test harness. A heap dump later reveals that the `StreamState` references are unreachable.

I haven't run this through the unit test suite, but doing that now. Wanted to get this up as I think folks are waiting for it for 2.4.1

Closes #24013 from abellina/SPARK-26604_cherry_pick_2_4.

Lead-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Co-authored-by: Alessandro Bellina <abellina@yahoo-inc.com>
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
(cherry picked from commit 216eeec)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
  • Loading branch information...
2 people authored and vanzin committed Mar 8, 2019
1 parent a1ca566 commit c45f8da3af6000645ee76544940a6bdc5477884b
@@ -23,6 +23,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -49,7 +50,7 @@
final Iterator<ManagedBuffer> buffers;

// The channel associated to the stream
Channel associatedChannel = null;
final Channel associatedChannel;

// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
// that the caller only requests each chunk one at a time, in order.
@@ -58,9 +59,10 @@
// Used to keep track of the number of chunks being transferred and not finished yet.
volatile long chunksBeingTransferred = 0L;

StreamState(String appId, Iterator<ManagedBuffer> buffers) {
StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
this.associatedChannel = channel;
}
}

@@ -71,13 +73,6 @@ public OneForOneStreamManager() {
streams = new ConcurrentHashMap<>();
}

@Override
public void registerChannel(Channel channel, long streamId) {
if (streams.containsKey(streamId)) {
streams.get(streamId).associatedChannel = channel;
}
}

@Override
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
StreamState state = streams.get(streamId);
@@ -195,11 +190,19 @@ public long chunksBeingTransferred() {
*
* If an app ID is provided, only callers who've authenticated with the given app ID will be
* allowed to fetch from this stream.
*
* This method also associates the stream with a single client connection, which is guaranteed
* to be the only reader of the stream. Once the connection is closed, the stream will never
* be used again, enabling cleanup by `connectionTerminated`.
*/
public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
long myStreamId = nextStreamId.getAndIncrement();
streams.put(myStreamId, new StreamState(appId, buffers));
streams.put(myStreamId, new StreamState(appId, buffers, channel));
return myStreamId;
}

@VisibleForTesting
public int numStreamStates() {
return streams.size();
}
}
@@ -60,16 +60,6 @@ public ManagedBuffer openStream(String streamId) {
throw new UnsupportedOperationException();
}

/**
* Associates a stream with a single client connection, which is guaranteed to be the only reader
* of the stream. The getChunk() method will be called serially on this connection and once the
* connection is closed, the stream will never be used again, enabling cleanup.
*
* This must be called before the first getChunk() on the stream, but it may be invoked multiple
* times with the same channel and stream id.
*/
public void registerChannel(Channel channel, long streamId) { }

/**
* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
* to read from the associated streams again, so any state can be cleaned up.
@@ -133,7 +133,6 @@ private void processFetchRequest(final ChunkFetchRequest req) {
ManagedBuffer buf;
try {
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
logger.error(String.format("Error opening block %s for request from %s",
@@ -62,8 +62,10 @@ public void handleFetchRequestAndStreamRequest() throws Exception {
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
streamManager.registerChannel(channel, streamId);
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);

assert streamManager.numStreamStates() == 1;

TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L);
@@ -98,6 +100,9 @@ public void handleFetchRequestAndStreamRequest() throws Exception {
requestHandler.handle(request3);
verify(channel, times(1)).close();
assert responseAndPromisePairs.size() == 3;

streamManager.connectionTerminated(channel);
assert streamManager.numStreamStates() == 0;
}

private class ExtendedChannelPromise extends DefaultChannelPromise {
@@ -37,14 +37,15 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
buffers.add(buffer1);
buffers.add(buffer2);
long streamId = manager.registerStream("appId", buffers.iterator());

Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
manager.registerChannel(dummyChannel, streamId);
manager.registerStream("appId", buffers.iterator(), dummyChannel);
assert manager.numStreamStates() == 1;

manager.connectionTerminated(dummyChannel);

Mockito.verify(buffer1, Mockito.times(1)).release();
Mockito.verify(buffer2, Mockito.times(1)).release();
assert manager.numStreamStates() == 0;
}
}
@@ -91,7 +91,7 @@ protected void handleMessage(
OpenBlocks msg = (OpenBlocks) msgObj;
checkAuth(client, msg.appId);
long streamId = streamManager.registerStream(client.getClientId(),
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
if (logger.isTraceEnabled()) {
logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
streamId,
@@ -101,7 +101,8 @@ public void testOpenShuffleBlocks() {
@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
any());
Iterator<ManagedBuffer> buffers = stream.getValue();
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());
@@ -59,7 +59,8 @@ class NettyBlockRpcServer(
val blocksNum = openBlocks.blockIds.length
val blocks = for (i <- (0 until blocksNum).view)
yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
client.getChannel)
logTrace(s"Registered streamId $streamId with $blocksNum buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)

0 comments on commit c45f8da

Please sign in to comment.
You can’t perform that action at this time.