Skip to content
Browse files

[SPARK-26604][CORE][BACKPORT-2.4] Clean up channel registration for StreamManager


## What changes were proposed in this pull request?

This is mostly a clean backport of #23521 to branch-2.4

## How was this patch tested?

I've tested this with a hack in `TransportRequestHandler` to force `ChunkFetchRequest` to get dropped.

Then I made a number of `ExternalShuffleClient.fetchChunk` requests (which send `OpenBlocks` followed by `ChunkFetchRequest`) and closed out of my test harness. A heap dump later reveals that the `StreamState` references are unreachable.

I haven't run this through the unit test suite yet, but I am doing that now. I wanted to get this up as I think folks are waiting for it for 2.4.1.

Closes #24013 from abellina/SPARK-26604_cherry_pick_2_4.

Lead-authored-by: Liang-Chi Hsieh <>
Co-authored-by: Alessandro Bellina <>
Signed-off-by: Marcelo Vanzin <>
(cherry picked from commit 216eeec)
Signed-off-by: Marcelo Vanzin <>
  • Loading branch information...
2 people authored and vanzin committed Mar 8, 2019
1 parent a1ca566 commit c45f8da3af6000645ee76544940a6bdc5477884b
@@ -23,6 +23,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -49,7 +50,7 @@
final Iterator<ManagedBuffer> buffers;

// The channel associated to the stream
Channel associatedChannel = null;
final Channel associatedChannel;

// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
// that the caller only requests each chunk one at a time, in order.
@@ -58,9 +59,10 @@
// Used to keep track of the number of chunks being transferred and not finished yet.
volatile long chunksBeingTransferred = 0L;

StreamState(String appId, Iterator<ManagedBuffer> buffers) {
StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
this.associatedChannel = channel;

@@ -71,13 +73,6 @@ public OneForOneStreamManager() {
streams = new ConcurrentHashMap<>();

public void registerChannel(Channel channel, long streamId) {
if (streams.containsKey(streamId)) {
streams.get(streamId).associatedChannel = channel;

public ManagedBuffer getChunk(long streamId, int chunkIndex) {
StreamState state = streams.get(streamId);
@@ -195,11 +190,19 @@ public long chunksBeingTransferred() {
* If an app ID is provided, only callers who've authenticated with the given app ID will be
* allowed to fetch from this stream.
* This method also associates the stream with a single client connection, which is guaranteed
* to be the only reader of the stream. Once the connection is closed, the stream will never
* be used again, enabling cleanup by `connectionTerminated`.
public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
long myStreamId = nextStreamId.getAndIncrement();
streams.put(myStreamId, new StreamState(appId, buffers));
streams.put(myStreamId, new StreamState(appId, buffers, channel));
return myStreamId;

public int numStreamStates() {
return streams.size();
@@ -60,16 +60,6 @@ public ManagedBuffer openStream(String streamId) {
throw new UnsupportedOperationException();

* Associates a stream with a single client connection, which is guaranteed to be the only reader
* of the stream. The getChunk() method will be called serially on this connection and once the
* connection is closed, the stream will never be used again, enabling cleanup.
* This must be called before the first getChunk() on the stream, but it may be invoked multiple
* times with the same channel and stream id.
public void registerChannel(Channel channel, long streamId) { }

* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
* to read from the associated streams again, so any state can be cleaned up.
@@ -133,7 +133,6 @@ private void processFetchRequest(final ChunkFetchRequest req) {
ManagedBuffer buf;
try {
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
logger.error(String.format("Error opening block %s for request from %s",
@@ -62,8 +62,10 @@ public void handleFetchRequestAndStreamRequest() throws Exception {
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
streamManager.registerChannel(channel, streamId);
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);

assert streamManager.numStreamStates() == 1;

TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L);
@@ -98,6 +100,9 @@ public void handleFetchRequestAndStreamRequest() throws Exception {
verify(channel, times(1)).close();
assert responseAndPromisePairs.size() == 3;

assert streamManager.numStreamStates() == 0;

private class ExtendedChannelPromise extends DefaultChannelPromise {
@@ -37,14 +37,15 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
long streamId = manager.registerStream("appId", buffers.iterator());

Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
manager.registerChannel(dummyChannel, streamId);
manager.registerStream("appId", buffers.iterator(), dummyChannel);
assert manager.numStreamStates() == 1;


Mockito.verify(buffer1, Mockito.times(1)).release();
Mockito.verify(buffer2, Mockito.times(1)).release();
assert manager.numStreamStates() == 0;
@@ -91,7 +91,7 @@ protected void handleMessage(
OpenBlocks msg = (OpenBlocks) msgObj;
checkAuth(client, msg.appId);
long streamId = streamManager.registerStream(client.getClientId(),
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
if (logger.isTraceEnabled()) {
logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
@@ -101,7 +101,8 @@ public void testOpenShuffleBlocks() {
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
Iterator<ManagedBuffer> buffers = stream.getValue();
@@ -59,7 +59,8 @@ class NettyBlockRpcServer(
val blocksNum = openBlocks.blockIds.length
val blocks = for (i <- (0 until blocksNum).view)
yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
logTrace(s"Registered streamId $streamId with $blocksNum buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)

0 comments on commit c45f8da

Please sign in to comment.
You can’t perform that action at this time.