From 82e3fa72af108c92a179ea441bf67e8b2e494d23 Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Thu, 10 Aug 2017 13:29:13 +0800 Subject: [PATCH 1/4] [FLINK-7406][network] Implement Netty receiver incoming pipeline for credit-based --- .../netty/CreditBasedClientHandler.java | 277 ++++++++ .../io/network/netty/NettyMessage.java | 15 +- .../netty/PartitionRequestClientHandler.java | 8 +- .../network/netty/PartitionRequestQueue.java | 3 +- .../consumer/RemoteInputChannel.java | 257 +++++-- .../netty/NettyMessageSerializationTest.java | 3 +- .../PartitionRequestClientHandlerTest.java | 151 ++-- .../partition/InputGateConcurrentTest.java | 2 +- .../partition/InputGateFairnessTest.java | 8 +- .../consumer/RemoteInputChannelTest.java | 665 ++++++++++++++++-- 10 files changed, 1175 insertions(+), 214 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java new file mode 100644 index 0000000000000..1f1858843efc8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.netty; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException; +import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException; +import org.apache.flink.runtime.io.network.netty.exception.TransportException; +import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; +import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.SocketAddress; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Channel handler to read the messages of buffer response or error response from the + * producer, to write and flush the unannounced credits for the producer. 
+ */
+class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
+
+	private static final Logger LOG = LoggerFactory.getLogger(CreditBasedClientHandler.class);
+
+	/** Channels, which already requested partitions from the producers. */
+	private final ConcurrentMap<InputChannelID, RemoteInputChannel> inputChannels = new ConcurrentHashMap<>();
+
+	private final AtomicReference<Throwable> channelError = new AtomicReference<>();
+
+	/**
+	 * Set of cancelled partition requests. A request is cancelled iff an input channel is cleared
+	 * while data is still coming in for this channel.
+	 */
+	private final ConcurrentMap<InputChannelID, InputChannelID> cancelled = new ConcurrentHashMap<>();
+
+	private volatile ChannelHandlerContext ctx;
+
+	// ------------------------------------------------------------------------
+	// Input channel/receiver registration
+	// ------------------------------------------------------------------------
+
+	void addInputChannel(RemoteInputChannel listener) throws IOException {
+		checkError();
+
+		// Register atomically: putIfAbsent keeps an already registered channel instead of
+		// racing a separate containsKey()/put() pair on the concurrent map.
+		inputChannels.putIfAbsent(listener.getInputChannelId(), listener);
+	}
+
+	void removeInputChannel(RemoteInputChannel listener) {
+		inputChannels.remove(listener.getInputChannelId());
+	}
+
+	void cancelRequestFor(InputChannelID inputChannelId) {
+		if (inputChannelId == null || ctx == null) {
+			return;
+		}
+
+		if (cancelled.putIfAbsent(inputChannelId, inputChannelId) == null) {
+			ctx.writeAndFlush(new NettyMessage.CancelPartitionRequest(inputChannelId));
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	// Network events
+	// ------------------------------------------------------------------------
+
+	@Override
+	public void channelActive(final ChannelHandlerContext ctx) throws Exception {
+		if (this.ctx == null) {
+			this.ctx = ctx;
+		}
+
+		super.channelActive(ctx);
+	}
+
+	@Override
+	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+		// Unexpected close.
In normal operation, the client closes the connection after all input + // channels have been removed. This indicates a problem with the remote task manager. + if (!inputChannels.isEmpty()) { + final SocketAddress remoteAddr = ctx.channel().remoteAddress(); + + notifyAllChannelsOfErrorAndClose(new RemoteTransportException( + "Connection unexpectedly closed by remote task manager '" + remoteAddr + "'. " + + "This might indicate that the remote task manager was lost.", remoteAddr)); + } + + super.channelInactive(ctx); + } + + /** + * Called on exceptions in the client handler pipeline. + * + *

<p>Remote exceptions are received as regular payload.
+	 */
+	@Override
+	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+
+		if (cause instanceof TransportException) {
+			notifyAllChannelsOfErrorAndClose(cause);
+		} else {
+			final SocketAddress remoteAddr = ctx.channel().remoteAddress();
+
+			final TransportException tex;
+
+			// Improve on the connection reset by peer error message (null-safe: the message may be absent)
+			if (cause instanceof IOException && "Connection reset by peer".equals(cause.getMessage())) {
+				tex = new RemoteTransportException("Lost connection to task manager '" + remoteAddr + "'. "
+					+ "This indicates that the remote task manager was lost.", remoteAddr, cause);
+			} else {
+				tex = new LocalTransportException(cause.getMessage(), ctx.channel().localAddress(), cause);
+			}
+
+			notifyAllChannelsOfErrorAndClose(tex);
+		}
+	}
+
+	@Override
+	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+		try {
+			decodeMsg(msg);
+		} catch (Throwable t) {
+			notifyAllChannelsOfErrorAndClose(t);
+		}
+	}
+
+	private void notifyAllChannelsOfErrorAndClose(Throwable cause) {
+		if (channelError.compareAndSet(null, cause)) {
+			try {
+				for (RemoteInputChannel inputChannel : inputChannels.values()) {
+					inputChannel.onError(cause);
+				}
+			} catch (Throwable t) {
+				// We can only swallow the Exception at this point. :(
+				LOG.warn("An Exception was thrown during error notification of a remote input channel.", t);
+			} finally {
+				inputChannels.clear();
+
+				if (ctx != null) {
+					ctx.close();
+				}
+			}
+		}
+	}
+
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Checks for an error and rethrows it if one was reported.
+ */ + private void checkError() throws IOException { + final Throwable t = channelError.get(); + + if (t != null) { + if (t instanceof IOException) { + throw (IOException) t; + } else { + throw new IOException("There has been an error in the channel.", t); + } + } + } + + private void decodeMsg(Object msg) throws Throwable { + final Class msgClazz = msg.getClass(); + + // ---- Buffer -------------------------------------------------------- + if (msgClazz == NettyMessage.BufferResponse.class) { + NettyMessage.BufferResponse bufferOrEvent = (NettyMessage.BufferResponse) msg; + + RemoteInputChannel inputChannel = inputChannels.get(bufferOrEvent.receiverId); + if (inputChannel == null) { + bufferOrEvent.releaseBuffer(); + + cancelRequestFor(bufferOrEvent.receiverId); + + return; + } + + decodeBufferOrEvent(inputChannel, bufferOrEvent); + + } else if (msgClazz == NettyMessage.ErrorResponse.class) { + // ---- Error --------------------------------------------------------- + NettyMessage.ErrorResponse error = (NettyMessage.ErrorResponse) msg; + + SocketAddress remoteAddr = ctx.channel().remoteAddress(); + + if (error.isFatalError()) { + notifyAllChannelsOfErrorAndClose(new RemoteTransportException( + "Fatal error at remote task manager '" + remoteAddr + "'.", + remoteAddr, + error.cause)); + } else { + RemoteInputChannel inputChannel = inputChannels.get(error.receiverId); + + if (inputChannel != null) { + if (error.cause.getClass() == PartitionNotFoundException.class) { + inputChannel.onFailedPartitionRequest(); + } else { + inputChannel.onError(new RemoteTransportException( + "Error at remote task manager '" + remoteAddr + "'.", + remoteAddr, + error.cause)); + } + } + } + } else { + throw new IllegalStateException("Received unknown message from producer: " + msg.getClass()); + } + } + + private void decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessage.BufferResponse bufferOrEvent) throws Throwable { + try { + if (bufferOrEvent.isBuffer()) { + // ---- 
Buffer ------------------------------------------------ + + // Early return for empty buffers. Otherwise Netty's readBytes() throws an + // IndexOutOfBoundsException. + if (bufferOrEvent.getSize() == 0) { + inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, bufferOrEvent.backlog); + return; + } + + Buffer buffer = inputChannel.requestBuffer(); + if (buffer != null) { + buffer.setSize(bufferOrEvent.getSize()); + bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer()); + + inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog); + } else if (inputChannel.isReleased()) { + cancelRequestFor(bufferOrEvent.receiverId); + } else { + throw new IllegalStateException("No buffer available in credit-based input channel."); + } + } else { + // ---- Event ------------------------------------------------- + // TODO We can just keep the serialized data in the Netty buffer and release it later at the reader + byte[] byteArray = new byte[bufferOrEvent.getSize()]; + bufferOrEvent.getNettyBuffer().readBytes(byteArray); + + MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray); + Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false); + + inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog); + } + } finally { + bufferOrEvent.releaseBuffer(); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java index 89fb9e85e3301..db1b899b83209 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java @@ -221,6 +221,8 @@ static class BufferResponse extends NettyMessage { final int sequenceNumber; + final int backlog; + // ---- Deserialization ----------------------------------------------- final boolean isBuffer; @@ -232,7 
+234,8 @@ static class BufferResponse extends NettyMessage { private BufferResponse( ByteBuf retainedSlice, boolean isBuffer, int sequenceNumber, - InputChannelID receiverId) { + InputChannelID receiverId, + int backlog) { // When deserializing we first have to request a buffer from the respective buffer // provider (at the handler) and copy the buffer from Netty's space to ours. Only // retainedSlice is set in this case. @@ -242,15 +245,17 @@ private BufferResponse( this.isBuffer = isBuffer; this.sequenceNumber = sequenceNumber; this.receiverId = checkNotNull(receiverId); + this.backlog = backlog; } - BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId) { + BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId, int backlog) { this.buffer = checkNotNull(buffer); this.retainedSlice = null; this.isBuffer = buffer.isBuffer(); this.size = buffer.getSize(); this.sequenceNumber = sequenceNumber; this.receiverId = checkNotNull(receiverId); + this.backlog = backlog; } boolean isBuffer() { @@ -280,7 +285,7 @@ void releaseBuffer() { ByteBuf write(ByteBufAllocator allocator) throws IOException { checkNotNull(buffer, "No buffer instance to serialize."); - int length = 16 + 4 + 1 + 4 + buffer.getSize(); + int length = 16 + 4 + 4 + 1 + 4 + buffer.getSize(); ByteBuf result = null; try { @@ -288,6 +293,7 @@ ByteBuf write(ByteBufAllocator allocator) throws IOException { receiverId.writeTo(result); result.writeInt(sequenceNumber); + result.writeInt(backlog); result.writeBoolean(buffer.isBuffer()); result.writeInt(buffer.getSize()); result.writeBytes(buffer.getNioBuffer()); @@ -309,12 +315,13 @@ ByteBuf write(ByteBufAllocator allocator) throws IOException { static BufferResponse readFrom(ByteBuf buffer) { InputChannelID receiverId = InputChannelID.fromByteBuf(buffer); int sequenceNumber = buffer.readInt(); + int backlog = buffer.readInt(); boolean isBuffer = buffer.readBoolean(); int size = buffer.readInt(); ByteBuf retainedSlice = 
buffer.readSlice(size).retain(); - return new BufferResponse(retainedSlice, isBuffer, sequenceNumber, receiverId); + return new BufferResponse(retainedSlice, isBuffer, sequenceNumber, receiverId, backlog); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java index 566b215b9979d..ab4798e21720c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java @@ -276,7 +276,7 @@ private boolean decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessag // Early return for empty buffers. Otherwise Netty's readBytes() throws an // IndexOutOfBoundsException. if (bufferOrEvent.getSize() == 0) { - inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber); + inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, -1); return true; } @@ -295,7 +295,7 @@ private boolean decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessag buffer.setSize(bufferOrEvent.getSize()); bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer()); - inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber); + inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, -1); return true; } @@ -318,7 +318,7 @@ else if (bufferProvider.isDestroyed()) { MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray); Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false); - inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber); + inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, -1); return true; } @@ -450,7 +450,7 @@ public void run() { RemoteInputChannel inputChannel = inputChannels.get(stagedBufferResponse.receiverId); if (inputChannel != null) { - inputChannel.onBuffer(buffer, stagedBufferResponse.sequenceNumber); + 
inputChannel.onBuffer(buffer, stagedBufferResponse.sequenceNumber, -1); success = true; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java index ff0f1307dbfc7..41f87ae8c1731 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java @@ -193,7 +193,8 @@ private void writeAndFlushNextMessageIfPossible(final Channel channel) throws IO BufferResponse msg = new BufferResponse( next.buffer(), reader.getSequenceNumber(), - reader.getReceiverId()); + reader.getReceiverId(), + 0); if (isEndOfPartitionEvent(next.buffer())) { reader.notifySubpartitionConsumed(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index cd00934e2eff6..02c7b34863ed4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.io.network.partition.consumer; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.io.network.ConnectionID; @@ -32,11 +33,13 @@ import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.util.ExceptionUtils; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.ArrayDeque; +import java.util.Collections; import java.util.List; import 
java.util.ArrayList; -import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -82,17 +85,19 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, /** The initial number of exclusive buffers assigned to this channel. */ private int initialCredit; - /** The current available buffers including both exclusive buffers and requested floating buffers. */ - private final ArrayDeque availableBuffers = new ArrayDeque<>(); + /** The available buffer queue wraps both exclusive and requested floating buffers. */ + private final AvailableBufferQueue bufferQueue = new AvailableBufferQueue(); /** The number of available buffers that have not been announced to the producer yet. */ private final AtomicInteger unannouncedCredit = new AtomicInteger(0); - /** The number of unsent buffers in the producer's sub partition. */ - private final AtomicInteger senderBacklog = new AtomicInteger(0); + /** The number of required buffers that equals to sender's backlog plus initial credit. */ + @GuardedBy("bufferQueue") + private int numRequiredBuffers; /** The tag indicates whether this channel is waiting for additional floating buffers from the buffer pool. 
*/ - private final AtomicBoolean isWaitingForFloatingBuffers = new AtomicBoolean(false); + @GuardedBy("bufferQueue") + private boolean isWaitingForFloatingBuffers; public RemoteInputChannel( SingleInputGate inputGate, @@ -133,10 +138,11 @@ void assignExclusiveSegments(List segments) { checkArgument(segments.size() > 0, "The number of exclusive buffers per channel should be larger than 0."); this.initialCredit = segments.size(); + this.numRequiredBuffers = segments.size(); - synchronized(availableBuffers) { + synchronized(bufferQueue) { for (MemorySegment segment : segments) { - availableBuffers.add(new Buffer(segment, this)); + bufferQueue.addExclusiveBuffer(new Buffer(segment, this), numRequiredBuffers); } } } @@ -211,7 +217,7 @@ void sendTaskEvent(TaskEvent event) throws IOException { // ------------------------------------------------------------------------ @Override - boolean isReleased() { + public boolean isReleased() { return isReleased.get(); } @@ -227,7 +233,8 @@ void notifySubpartitionConsumed() { void releaseAllResources() throws IOException { if (isReleased.compareAndSet(false, true)) { - // Gather all exclusive buffers and recycle them to global pool in batch + // Gather all exclusive buffers and recycle them to global pool in batch, because + // we do not want to trigger redistribution of buffers after each recycle. 
final List exclusiveRecyclingSegments = new ArrayList<>(); synchronized (receivedBuffers) { @@ -240,16 +247,8 @@ void releaseAllResources() throws IOException { } } } - - synchronized (availableBuffers) { - Buffer buffer; - while ((buffer = availableBuffers.poll()) != null) { - if (buffer.getRecycler() == this) { - exclusiveRecyclingSegments.add(buffer.getMemorySegment()); - } else { - buffer.recycle(); - } - } + synchronized (bufferQueue) { + bufferQueue.releaseAll(exclusiveRecyclingSegments); } if (exclusiveRecyclingSegments.size() > 0) { @@ -287,81 +286,93 @@ void notifyCreditAvailable() { } /** - * Exclusive buffer is recycled to this input channel directly and it may trigger notify - * credit to producer. + * Exclusive buffer is recycled to this input channel directly and it may trigger return extra + * floating buffer and notify increased credit to the producer. * * @param segment The exclusive segment of this channel. */ @Override public void recycle(MemorySegment segment) { - synchronized (availableBuffers) { - // Important: the isReleased check should be inside the synchronized block. - // that way the segment can also be returned to global pool after added into - // the available queue during releasing all resources. + int numAddedBuffers; + + synchronized (bufferQueue) { + // Important: check the isReleased state inside synchronized block, so there is no + // race condition when recycle and releaseAllResources running in parallel. 
if (isReleased.get()) { try { - inputGate.returnExclusiveSegments(Arrays.asList(segment)); + inputGate.returnExclusiveSegments(Collections.singletonList(segment)); return; } catch (Throwable t) { ExceptionUtils.rethrow(t); } } - availableBuffers.add(new Buffer(segment, this)); + numAddedBuffers = bufferQueue.addExclusiveBuffer(new Buffer(segment, this), numRequiredBuffers); } - if (unannouncedCredit.getAndAdd(1) == 0) { + if (numAddedBuffers > 0 && unannouncedCredit.getAndAdd(numAddedBuffers) == 0) { notifyCreditAvailable(); } } public int getNumberOfAvailableBuffers() { - synchronized (availableBuffers) { - return availableBuffers.size(); + synchronized (bufferQueue) { + return bufferQueue.getAvailableBufferSize(); } } + @VisibleForTesting + public int getNumberOfRequiredBuffers() { + return numRequiredBuffers; + } + /** * The Buffer pool notifies this channel of an available floating buffer. If the channel is released or * currently does not need extra buffers, the buffer should be recycled to the buffer pool. Otherwise, - * the buffer will be added into the availableBuffers queue and the unannounced credit is - * increased by one. + * the buffer will be added into the bufferQueue and the unannounced credit is increased + * by one. * * @param buffer Buffer that becomes available in buffer pool. * @return True when this channel is waiting for more floating buffers, otherwise false. */ @Override public boolean notifyBufferAvailable(Buffer buffer) { - checkState(isWaitingForFloatingBuffers.get(), "This channel should be waiting for floating buffers."); + // Check the isReleased state outside synchronized block first to avoid + // deadlock with releaseAllResources running in parallel. + if (isReleased.get()) { + buffer.recycle(); + return false; + } - synchronized (availableBuffers) { - // Important: the isReleased check should be inside the synchronized block. 
- if (isReleased.get() || availableBuffers.size() >= senderBacklog.get()) { - isWaitingForFloatingBuffers.set(false); - buffer.recycle(); + boolean needMoreBuffers = false; + synchronized (bufferQueue) { + checkState(isWaitingForFloatingBuffers, "This channel should be waiting for floating buffers."); + // Important: double check the isReleased state inside synchronized block, so there is no + // race condition when notifyBufferAvailable and releaseAllResources running in parallel. + if (isReleased.get() || bufferQueue.getAvailableBufferSize() >= numRequiredBuffers) { + buffer.recycle(); return false; } - availableBuffers.add(buffer); - - if (unannouncedCredit.getAndAdd(1) == 0) { - notifyCreditAvailable(); - } + bufferQueue.addFloatingBuffer(buffer); - if (availableBuffers.size() >= senderBacklog.get()) { - isWaitingForFloatingBuffers.set(false); - return false; + if (bufferQueue.getAvailableBufferSize() == numRequiredBuffers) { + isWaitingForFloatingBuffers = false; } else { - return true; + needMoreBuffers = true; } } + + if (unannouncedCredit.getAndAdd(1) == 0) { + notifyCreditAvailable(); + } + + return needMoreBuffers; } @Override public void notifyBufferDestroyed() { - if (!isWaitingForFloatingBuffers.compareAndSet(true, false)) { - throw new IllegalStateException("This channel should be waiting for floating buffers currently."); - } + // Nothing to do actually. } // ------------------------------------------------------------------------ @@ -394,7 +405,58 @@ public BufferProvider getBufferProvider() throws IOException { return inputGate.getBufferProvider(); } - public void onBuffer(Buffer buffer, int sequenceNumber) { + /** + * Requests buffer from input channel directly for receiving network data. + * It should always return an available buffer in credit-based mode unless + * the channel has been released. + * + * @return The available buffer. 
+ */ + @Nullable + public Buffer requestBuffer() { + synchronized (bufferQueue) { + return bufferQueue.takeBuffer(); + } + } + + /** + * Receives the backlog from the producer's buffer response. If the number of available + * buffers is less than backlog + initialCredit, it will request floating buffers from the buffer + * pool, and then notify unannounced credits to the producer. + * + * @param backlog The number of unsent buffers in the producer's sub partition. + */ + @VisibleForTesting + void onSenderBacklog(int backlog) throws IOException { + int numRequestedBuffers = 0; + + synchronized (bufferQueue) { + // Important: check the isReleased state inside synchronized block, so there is no + // race condition when onSenderBacklog and releaseAllResources running in parallel. + if (isReleased.get()) { + return; + } + + numRequiredBuffers = backlog + initialCredit; + while (bufferQueue.getAvailableBufferSize() < numRequiredBuffers && !isWaitingForFloatingBuffers) { + Buffer buffer = inputGate.getBufferPool().requestBuffer(); + if (buffer != null) { + bufferQueue.addFloatingBuffer(buffer); + numRequestedBuffers++; + } else if (inputGate.getBufferProvider().addBufferListener(this)) { + // If the channel has not got enough buffers, register it as listener to wait for more floating buffers. 
+ isWaitingForFloatingBuffers = true; + break; + } + } + } + + if (numRequestedBuffers > 0 && unannouncedCredit.getAndAdd(numRequestedBuffers) == 0) { + notifyCreditAvailable(); + } + } + + public void onBuffer(Buffer buffer, int sequenceNumber, int backlog) throws IOException { boolean success = false; try { @@ -416,6 +478,10 @@ public void onBuffer(Buffer buffer, int sequenceNumber) { } } } + + if (success && backlog >= 0) { + onSenderBacklog(backlog); + } } finally { if (!success) { buffer.recycle(); @@ -423,16 +489,23 @@ public void onBuffer(Buffer buffer, int sequenceNumber) { } } - public void onEmptyBuffer(int sequenceNumber) { + public void onEmptyBuffer(int sequenceNumber, int backlog) throws IOException { + boolean success = false; + synchronized (receivedBuffers) { if (!isReleased.get()) { if (expectedSequenceNumber == sequenceNumber) { expectedSequenceNumber++; + success = true; } else { onError(new BufferReorderingException(expectedSequenceNumber, sequenceNumber)); } } } + + if (success && backlog >= 0) { + onSenderBacklog(backlog); + } } public void onFailedPartitionRequest() { @@ -462,4 +535,82 @@ public String getMessage() { expectedSequenceNumber, actualSequenceNumber); } } + + /** + * Manages the exclusive and floating buffers of this channel, and handles the + * internal buffer related logic. + */ + private static class AvailableBufferQueue { + + /** The current available floating buffers from the fixed buffer pool. */ + private final ArrayDeque floatingBuffers; + + /** The current available exclusive buffers from the global buffer pool. */ + private final ArrayDeque exclusiveBuffers; + + AvailableBufferQueue() { + this.exclusiveBuffers = new ArrayDeque<>(); + this.floatingBuffers = new ArrayDeque<>(); + } + + /** + * Adds an exclusive buffer (back) into the queue and recycles one floating buffer if the + * number of available buffers in queue is more than the required amount. 
+		 *
+		 * @param buffer The exclusive buffer to add
+		 * @param numRequiredBuffers The number of required buffers
+		 *
+		 * @return 1 if the queue grew by one buffer, 0 if a floating buffer was recycled in exchange
+		 */
+		int addExclusiveBuffer(Buffer buffer, int numRequiredBuffers) {
+			exclusiveBuffers.add(buffer);
+			// Only a surplus hands a floating buffer back; poll() yields null when none is queued.
+			Buffer surplus = getAvailableBufferSize() > numRequiredBuffers ? floatingBuffers.poll() : null;
+			if (surplus == null) {
+				return 1;
+			}
+			surplus.recycle();
+			return 0;
+		}
+
+		void addFloatingBuffer(Buffer buffer) {
+			floatingBuffers.add(buffer);
+		}
+
+		/**
+		 * Takes the floating buffer first in order to make full use of floating
+		 * buffers reasonably.
+		 *
+		 * @return An available floating or exclusive buffer, may be null
+		 * if the channel is released.
+		 */
+		@Nullable
+		Buffer takeBuffer() {
+			if (floatingBuffers.size() > 0) {
+				return floatingBuffers.poll();
+			} else {
+				return exclusiveBuffers.poll();
+			}
+		}
+
+		/**
+		 * The floating buffer is recycled to local buffer pool directly, and the
+		 * exclusive buffer will be gathered to return to global buffer pool later.
+		 *
+		 * @param exclusiveSegments The list that we will add exclusive segments into.
+ */ + void releaseAll(List exclusiveSegments) { + Buffer buffer; + while ((buffer = floatingBuffers.poll()) != null) { + buffer.recycle(); + } + while ((buffer = exclusiveBuffers.poll()) != null) { + exclusiveSegments.add(buffer.getMemorySegment()); + } + } + + int getAvailableBufferSize() { + return floatingBuffers.size() + exclusiveBuffers.size(); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java index 0651f9782643b..8c87cebca2471 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java @@ -62,7 +62,7 @@ public void testEncodeDecode() { nioBuffer.putInt(i); } - NettyMessage.BufferResponse expected = new NettyMessage.BufferResponse(buffer, random.nextInt(), new InputChannelID()); + NettyMessage.BufferResponse expected = new NettyMessage.BufferResponse(buffer, random.nextInt(), new InputChannelID(), random.nextInt()); NettyMessage.BufferResponse actual = encodeAndDecode(expected); // Verify recycle has been called on buffer instance @@ -85,6 +85,7 @@ public void testEncodeDecode() { assertEquals(expected.sequenceNumber, actual.sequenceNumber); assertEquals(expected.receiverId, actual.receiverId); + assertEquals(expected.backlog, actual.backlog); } { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java index e1e5bd349f735..d3ff6c26afc26 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java @@ -30,23 +30,16 @@ import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.io.network.util.TestBufferFactory; -import org.apache.flink.runtime.testutils.DiscardingRecycler; import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator; import org.apache.flink.shaded.netty4.io.netty.channel.Channel; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; -import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import java.io.IOException; -import java.util.concurrent.atomic.AtomicReference; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -80,19 +73,19 @@ public void testReleaseInputChannelDuringDecode() throws Exception { when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); when(inputChannel.getBufferProvider()).thenReturn(bufferProvider); - final BufferResponse ReceivedBuffer = createBufferResponse( - TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId()); + final BufferResponse receivedBuffer = createBufferResponse( + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); final PartitionRequestClientHandler client = new PartitionRequestClientHandler(); client.addInputChannel(inputChannel); - client.channelRead(mock(ChannelHandlerContext.class), ReceivedBuffer); + client.channelRead(mock(ChannelHandlerContext.class), receivedBuffer); } /** * Tests a fix for FLINK-1761. * - *

<p> FLINK-1761 discovered an IndexOutOfBoundsException, when receiving buffers of size 0. + * <p>

FLINK-1761 discovered an IndexOutOfBoundsException, when receiving buffers of size 0. */ @Test public void testReceiveEmptyBuffer() throws Exception { @@ -108,10 +101,11 @@ public void testReceiveEmptyBuffer() throws Exception { final Buffer emptyBuffer = TestBufferFactory.createBuffer(); emptyBuffer.setSize(0); + final int backlog = 2; final BufferResponse receivedBuffer = createBufferResponse( - emptyBuffer, 0, inputChannel.getInputChannelId()); + emptyBuffer, 0, inputChannel.getInputChannelId(), backlog); - final PartitionRequestClientHandler client = new PartitionRequestClientHandler(); + final CreditBasedClientHandler client = new CreditBasedClientHandler(); client.addInputChannel(inputChannel); // Read the empty buffer @@ -119,6 +113,51 @@ public void testReceiveEmptyBuffer() throws Exception { // This should not throw an exception verify(inputChannel, never()).onError(any(Throwable.class)); + verify(inputChannel, times(1)).onEmptyBuffer(0, backlog); + } + + /** + * Verifies that {@link RemoteInputChannel#onBuffer(Buffer, int, int)} is called when a + * {@link BufferResponse} is received. 
+ */ + @Test + public void testReceiveBuffer() throws Exception { + final Buffer buffer = TestBufferFactory.createBuffer(); + final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); + when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); + when(inputChannel.requestBuffer()).thenReturn(buffer); + + final int backlog = 2; + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), backlog); + + final CreditBasedClientHandler client = new CreditBasedClientHandler(); + client.addInputChannel(inputChannel); + + client.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + + verify(inputChannel, times(1)).onBuffer(buffer, 0, backlog); + } + + /** + * Verifies that {@link RemoteInputChannel#onError(Throwable)} is called when a + * {@link BufferResponse} is received but no available buffer in input channel. + */ + @Test + public void testThrowExceptionForNoAvailableBuffer() throws Exception { + final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); + when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); + when(inputChannel.requestBuffer()).thenReturn(null); + + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); + + final CreditBasedClientHandler client = new CreditBasedClientHandler(); + client.addInputChannel(inputChannel); + + client.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + + verify(inputChannel, times(1)).onError(any(IllegalStateException.class)); } /** @@ -136,8 +175,8 @@ public void testReceivePartitionNotFoundException() throws Exception { when(inputChannel.getBufferProvider()).thenReturn(bufferProvider); final ErrorResponse partitionNotFound = new ErrorResponse( - new PartitionNotFoundException(new ResultPartitionID()), - inputChannel.getInputChannelId()); + new PartitionNotFoundException(new 
ResultPartitionID()), + inputChannel.getInputChannelId()); final PartitionRequestClientHandler client = new PartitionRequestClientHandler(); client.addInputChannel(inputChannel); @@ -169,95 +208,19 @@ public void testCancelBeforeActive() throws Exception { client.cancelRequestFor(inputChannel.getInputChannelId()); } - /** - * Tests that an unsuccessful message decode call for a staged message - * does not leave the channel with auto read set to false. - */ - @Test - @SuppressWarnings("unchecked") - public void testAutoReadAfterUnsuccessfulStagedMessage() throws Exception { - PartitionRequestClientHandler handler = new PartitionRequestClientHandler(); - EmbeddedChannel channel = new EmbeddedChannel(handler); - - final AtomicReference listener = new AtomicReference<>(); - - BufferProvider bufferProvider = mock(BufferProvider.class); - when(bufferProvider.addBufferListener(any(BufferListener.class))).thenAnswer(new Answer() { - @Override - @SuppressWarnings("unchecked") - public Boolean answer(InvocationOnMock invocation) throws Throwable { - listener.set((BufferListener) invocation.getArguments()[0]); - return true; - } - }); - - when(bufferProvider.requestBuffer()).thenReturn(null); - - InputChannelID channelId = new InputChannelID(0, 0); - RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); - when(inputChannel.getInputChannelId()).thenReturn(channelId); - - // The 3rd staged msg has a null buffer provider - when(inputChannel.getBufferProvider()).thenReturn(bufferProvider, bufferProvider, null); - - handler.addInputChannel(inputChannel); - - BufferResponse msg = createBufferResponse(createBuffer(true), 0, channelId); - - // Write 1st buffer msg. No buffer is available, therefore the buffer - // should be staged and auto read should be set to false. - assertTrue(channel.config().isAutoRead()); - channel.writeInbound(msg); - - // No buffer available, auto read false - assertFalse(channel.config().isAutoRead()); - - // Write more buffers... all staged. 
- msg = createBufferResponse(createBuffer(true), 1, channelId); - channel.writeInbound(msg); - - msg = createBufferResponse(createBuffer(true), 2, channelId); - channel.writeInbound(msg); - - // Notify about buffer => handle 1st msg - Buffer availableBuffer = createBuffer(false); - listener.get().notifyBufferAvailable(availableBuffer); - - // Start processing of staged buffers (in run pending tasks). Make - // sure that the buffer provider acts like it's destroyed. - when(bufferProvider.addBufferListener(any(BufferListener.class))).thenReturn(false); - when(bufferProvider.isDestroyed()).thenReturn(true); - - // Execute all tasks that are scheduled in the event loop. Further - // eventLoop().execute() calls are directly executed, if they are - // called in the scope of this call. - channel.runPendingTasks(); - - assertTrue(channel.config().isAutoRead()); - } - // --------------------------------------------------------------------------------------------- - private static Buffer createBuffer(boolean fill) { - MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(1024, null); - if (fill) { - for (int i = 0; i < 1024; i++) { - segment.put(i, (byte) i); - } - } - return new Buffer(segment, DiscardingRecycler.INSTANCE, true); - } - /** * Returns a deserialized buffer message as it would be received during runtime. 
*/ private BufferResponse createBufferResponse( Buffer buffer, int sequenceNumber, - InputChannelID receivingChannelId) throws IOException { + InputChannelID receivingChannelId, + int backlog) throws IOException { // Mock buffer to serialize - BufferResponse resp = new BufferResponse(buffer, sequenceNumber, receivingChannelId); + BufferResponse resp = new BufferResponse(buffer, sequenceNumber, receivingChannelId, backlog); ByteBuf serialized = resp.write(UnpooledByteBufAllocator.DEFAULT); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java index 6f98119ed287f..81788c9ca638e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java @@ -216,7 +216,7 @@ private static class RemoteChannelSource extends Source { @Override void addBuffer(Buffer buffer) throws Exception { - channel.onBuffer(buffer, seq++); + channel.onBuffer(buffer, seq++, -1); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java index 324a0607718a7..4e90265a3200e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java @@ -206,9 +206,9 @@ gate, i, new ResultPartitionID(), mock(ConnectionID.class), channels[i] = channel; for (int p = 0; p < buffersPerChannel; p++) { - channel.onBuffer(mockBuffer, p); + channel.onBuffer(mockBuffer, p, -1); } - channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel); + 
channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel, -1); gate.setInputChannel(new IntermediateResultPartitionID(), channel); } @@ -263,7 +263,7 @@ gate, i, new ResultPartitionID(), mock(ConnectionID.class), gate.setInputChannel(new IntermediateResultPartitionID(), channel); } - channels[11].onBuffer(mockBuffer, 0); + channels[11].onBuffer(mockBuffer, 0, -1); channelSequenceNums[11]++; // read all the buffers and the EOF event @@ -325,7 +325,7 @@ private void fillRandom( Collections.shuffle(poss); for (int i : poss) { - partitions[i].onBuffer(buffer, sequenceNumbers[i]++); + partitions[i].onBuffer(buffer, sequenceNumbers[i]++, -1); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index d791ced7eb384..863f8865c6f40 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -18,24 +18,28 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.netty.PartitionRequestClient; import org.apache.flink.runtime.io.network.partition.ProducerFailedException; import 
org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.util.TestBufferFactory; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import org.junit.Test; -import scala.Tuple2; import java.io.IOException; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; @@ -43,12 +47,14 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; +import scala.Tuple2; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyListOf; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -66,10 +72,10 @@ public void testExceptionOnReordering() throws Exception { final Buffer buffer = TestBufferFactory.createBuffer(); // The test - inputChannel.onBuffer(buffer.retain(), 0); + inputChannel.onBuffer(buffer.retain(), 0, -1); // This does not yet throw the exception, but sets the error at the channel. 
- inputChannel.onBuffer(buffer, 29); + inputChannel.onBuffer(buffer, 29, -1); try { inputChannel.getNextBuffer(); @@ -113,7 +119,7 @@ public Void call() throws Exception { for (int j = 0; j < 128; j++) { // this is the same buffer over and over again which will be // recycled by the RemoteInputChannel - inputChannel.onBuffer(buffer.retain(), j); + inputChannel.onBuffer(buffer.retain(), j, -1); } if (inputChannel.isReleased()) { @@ -301,81 +307,562 @@ public void testProducerFailedException() throws Exception { } /** - * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is - * recycled to available buffers directly and it triggers notify of announced credit. + * Tests to verify the behaviours of three different processes if the number of available + * buffers is less than required buffers. + * + * 1. Recycle the floating buffer + * 2. Recycle the exclusive buffer + * 3. Decrease the sender's backlog */ @Test - public void testRecycleExclusiveBufferBeforeReleased() throws Exception { - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); - - // Recycle exclusive segment - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + public void testAvailableBuffersLessThanRequiredBuffers() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32); + final int numExclusiveBuffers = 2; + final int numFloatingBuffers = 14; - assertEquals("There should be one buffer available after recycle.", - 1, inputChannel.getNumberOfAvailableBuffers()); - verify(inputChannel, times(1)).notifyCreditAvailable(); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = 
spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Prepare the exclusive and floating buffers to verify recycle logic later + final Buffer exclusiveBuffer = inputChannel.requestBuffer(); + assertNotNull(exclusiveBuffer); + + final int numRecycleFloatingBuffers = 2; + final ArrayDeque floatingBufferQueue = new ArrayDeque<>(numRecycleFloatingBuffers); + for (int i = 0; i < numRecycleFloatingBuffers; i++) { + Buffer floatingBuffer = bufferPool.requestBuffer(); + assertNotNull(floatingBuffer); + floatingBufferQueue.add(floatingBuffer); + } - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + verify(bufferPool, times(numRecycleFloatingBuffers)).requestBuffer(); + + // Receive the producer's backlog more than the number of available floating buffers + inputChannel.onSenderBacklog(14); + + // The channel requests (backlog + numExclusiveBuffers) floating buffers from local pool. 
+ // It does not get enough floating buffers and register as buffer listener + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 13 buffers available in the channel", + 13, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 16 buffers required in the channel", + 16, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Increase the backlog + inputChannel.onSenderBacklog(16); + + // The channel is already in the status of waiting for buffers and will not request any more + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 13 buffers available in the channel", + 13, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 18 buffers required in the channel", + 18, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one floating buffer + floatingBufferQueue.poll().recycle(); + + // Assign the floating buffer to the listener and the channel is still waiting for more floating buffers + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 18 buffers required in the channel", + 18, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one more floating buffer + floatingBufferQueue.poll().recycle(); + + // Assign the floating buffer to the listener and the channel is 
still waiting for more floating buffers + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 15 buffers available in the channel", + 15, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 18 buffers required in the channel", + 18, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Decrease the backlog + inputChannel.onSenderBacklog(15); + + // Only the number of required buffers is changed by (backlog + numExclusiveBuffers) + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 15 buffers available in the channel", + 15, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 17 buffers required in the channel", + 17, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one exclusive buffer + exclusiveBuffer.recycle(); + + // The exclusive buffer is returned to the channel directly + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 16 buffers available in the channel", + 16, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 17 buffers required in the channel", + 17, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffers available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); - assertEquals("There should be two buffers available after recycle.", - 2, inputChannel.getNumberOfAvailableBuffers()); - // It should be called only once 
when increased from zero. - verify(inputChannel, times(1)).notifyCreditAvailable(); + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } /** - * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is - * recycled to global pool via input gate when channel is released. + * Tests to verify the behaviours of recycling floating and exclusive buffers if the number of available + * buffers equals to required buffers. */ @Test - public void testRecycleExclusiveBufferAfterReleased() throws Exception { + public void testAvailableBuffersEqualToRequiredBuffers() throws Exception { // Setup - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); - - inputChannel.releaseAllResources(); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32); + final int numExclusiveBuffers = 2; + final int numFloatingBuffers = 14; - // Recycle exclusive segment after channel released - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Prepare the exclusive and floating buffers to verify recycle logic later + final Buffer exclusiveBuffer = inputChannel.requestBuffer(); + assertNotNull(exclusiveBuffer); + final Buffer floatingBuffer = bufferPool.requestBuffer(); + assertNotNull(floatingBuffer); + verify(bufferPool, times(1)).requestBuffer(); + + // Receive the producer's backlog + inputChannel.onSenderBacklog(12); + + // 
The channel requests (backlog + numExclusiveBuffers) floating buffers from local pool + // and gets enough floating buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 14 buffers required in the channel", + 14, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one floating buffer + floatingBuffer.recycle(); + + // The floating buffer is returned to local buffer directly because the channel is not waiting + // for floating buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 14 buffers required in the channel", + 14, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 1 buffer available in local pool", + 1, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one exclusive buffer + exclusiveBuffer.recycle(); + + // Return one extra floating buffer to the local pool because the number of available buffers + // already equals to required buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 14 buffers required in the channel", + 14, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 2 buffers available in local pool", + 2, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + 
inputChannel.releaseAllResources(); - assertEquals("Resource leak during recycling buffer after channel is released.", - 0, inputChannel.getNumberOfAvailableBuffers()); - verify(inputChannel, times(0)).notifyCreditAvailable(); - verify(inputGate, times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class)); + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } /** - * Tests {@link RemoteInputChannel#releaseAllResources()}, verifying the exclusive segments are - * recycled to global pool via input gate and no resource leak. + * Tests to verify the behaviours of recycling floating and exclusive buffers if the number of available + * buffers is more than required buffers by decreasing the sender's backlog. */ @Test - public void testReleaseExclusiveBuffers() throws Exception { + public void testAvailableBuffersMoreThanRequiredBuffers() throws Exception { // Setup - final SingleInputGate inputGate = mock(SingleInputGate.class); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32); + final int numExclusiveBuffers = 2; + final int numFloatingBuffers = 14; + + final SingleInputGate inputGate = createSingleInputGate(); final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Prepare the exclusive and floating buffers to verify recycle logic later + final Buffer exclusiveBuffer = inputChannel.requestBuffer(); + assertNotNull(exclusiveBuffer); + + final Buffer floatingBuffer = bufferPool.requestBuffer(); + assertNotNull(floatingBuffer); + + verify(bufferPool, times(1)).requestBuffer(); + + // Receive the producer's backlog + inputChannel.onSenderBacklog(12); + + // The channel 
gets enough floating buffers from local pool + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 14 buffers required in the channel", + 14, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Decrease the backlog to make the number of available buffers more than required buffers + inputChannel.onSenderBacklog(10); + + // Only the number of required buffers is changed by (backlog + numExclusiveBuffers) + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 12 buffers required in the channel", + 12, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one exclusive buffer + exclusiveBuffer.recycle(); + + // Return one extra floating buffer to the local pool because the number of available buffers + // is more than required buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 12 buffers required in the channel", + 12, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 1 buffer available in local pool", + 1, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one floating buffer + floatingBuffer.recycle(); + + // The floating buffer is returned to local 
pool directly because the channel is not waiting for + // floating buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 12 buffers required in the channel", + 12, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 2 buffers available in local pool", + 2, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); - // Assign exclusive segments to channel - final List exclusiveSegments = new ArrayList<>(); + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } + } + + /** + * Tests to verify that the buffer pool will distribute available floating buffers among + * all the channel listeners in a fair way. + */ + @Test + public void testFairDistributionFloatingBuffers() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(12, 32); final int numExclusiveBuffers = 2; - for (int i = 0; i < numExclusiveBuffers; i++) { - exclusiveSegments.add(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + final int numFloatingBuffers = 3; + + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel channel1 = spy(createRemoteInputChannel(inputGate)); + final RemoteInputChannel channel2 = spy(createRemoteInputChannel(inputGate)); + final RemoteInputChannel channel3 = spy(createRemoteInputChannel(inputGate)); + inputGate.setInputChannel(channel1.partitionId.getPartitionId(), channel1); + inputGate.setInputChannel(channel2.partitionId.getPartitionId(), channel2); + inputGate.setInputChannel(channel3.partitionId.getPartitionId(), channel3); + try { + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, 
numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Exhaust all the floating buffers + final List floatingBuffers = new ArrayList<>(numFloatingBuffers); + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } + + // Receive the producer's backlog to trigger request floating buffers from pool + // and register as listeners as a result + channel1.onSenderBacklog(8); + channel2.onSenderBacklog(8); + channel3.onSenderBacklog(8); + + verify(bufferPool, times(1)).addBufferListener(channel1); + verify(bufferPool, times(1)).addBufferListener(channel2); + verify(bufferPool, times(1)).addBufferListener(channel3); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel1.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel2.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel3.getNumberOfAvailableBuffers()); + + // Recycle three floating buffers to trigger notify buffer available + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + verify(channel1, times(1)).notifyBufferAvailable(any(Buffer.class)); + verify(channel2, times(1)).notifyBufferAvailable(any(Buffer.class)); + verify(channel3, times(1)).notifyBufferAvailable(any(Buffer.class)); + assertEquals("There should be 3 buffers available in the channel", 3, channel1.getNumberOfAvailableBuffers()); + assertEquals("There should be 3 buffers available in the channel", 3, channel2.getNumberOfAvailableBuffers()); + assertEquals("There should be 3 buffers available in the channel", 3, channel3.getNumberOfAvailableBuffers()); + + } finally { + // 
Release all the buffer resources + channel1.releaseAllResources(); + channel2.releaseAllResources(); + channel3.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); } - inputChannel.assignExclusiveSegments(exclusiveSegments); + } + + /** + * Tests to verify that there is no race condition with two things running in parallel: + * requesting floating buffers on sender backlog and some other thread releasing + * the input channel. + */ + @Test + public void testConcurrentOnSenderBacklogAndRelease() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(130, 32); + final int numExclusiveBuffers = 2; + final int numFloatingBuffers = 128; + + final ExecutorService executor = Executors.newFixedThreadPool(2); + + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + final Callable requestBufferTask = new Callable() { + @Override + public Void call() throws Exception { + while (true) { + for (int j = 1; j <= numFloatingBuffers; j++) { + inputChannel.onSenderBacklog(j); + } - assertEquals("The number of available buffers is not equal to the assigned amount.", - numExclusiveBuffers, inputChannel.getNumberOfAvailableBuffers()); + if (inputChannel.isReleased()) { + return null; + } + } + } + }; + + final Callable releaseTask = new Callable() { + @Override + public Void call() throws Exception { + inputChannel.releaseAllResources(); + + return null; + } + }; + + // Submit tasks and wait to finish + submitTasksAndWaitForResults(executor, new Callable[]{requestBufferTask, releaseTask}); + + 
assertEquals("There should be no buffers available in the channel.", + 0, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 130 buffers available in local pool.", + 130, bufferPool.getNumberOfAvailableMemorySegments() + networkBufferPool.getNumberOfAvailableMemorySegments()); - // Release this channel - inputChannel.releaseAllResources(); + } finally { + // Release all the buffer resources once exception + if (!inputChannel.isReleased()) { + inputChannel.releaseAllResources(); + } + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); - assertEquals("Resource leak after channel is released.", - 0, inputChannel.getNumberOfAvailableBuffers()); - verify(inputGate, times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class)); + executor.shutdown(); + } + } + + /** + * Tests to verify that there is no race condition with two things running in parallel: + * requesting floating buffers on sender backlog and some other thread recycling + * floating or exclusive buffers. 
+ */ + @Test + public void testConcurrentOnSenderBacklogAndRecycle() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(248, 32); + final int numExclusiveSegments = 120; + final int numFloatingBuffers = 128; + final int backlog = 128; + + final ExecutorService executor = Executors.newFixedThreadPool(3); + + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + + final Callable requestBufferTask = new Callable() { + @Override + public Void call() throws Exception { + for (int j = 1; j <= backlog; j++) { + inputChannel.onSenderBacklog(j); + } + + return null; + } + }; + + // Submit tasks and wait to finish + submitTasksAndWaitForResults(executor, new Callable[]{ + recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), + recycleFloatingBufferTask(bufferPool, numFloatingBuffers), + requestBufferTask}); + + assertEquals("There should be " + inputChannel.getNumberOfRequiredBuffers() +" buffers available in channel.", + inputChannel.getNumberOfRequiredBuffers(), inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be no buffers available in local pool.", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + + executor.shutdown(); + } + } + + /** + * Tests to verify that there is no race condition with two things running in parallel: + * recycling the exclusive or floating buffers and some other thread releasing the + * input 
channel. + */ + @Test + public void testConcurrentRecycleAndRelease() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(248, 32); + final int numExclusiveSegments = 120; + final int numFloatingBuffers = 128; + + final ExecutorService executor = Executors.newFixedThreadPool(3); + + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + + final Callable releaseTask = new Callable() { + @Override + public Void call() throws Exception { + inputChannel.releaseAllResources(); + + return null; + } + }; + + // Submit tasks and wait to finish + submitTasksAndWaitForResults(executor, new Callable[]{ + recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), + recycleFloatingBufferTask(bufferPool, numFloatingBuffers), + releaseTask}); + + assertEquals("There should be no buffers available in the channel.", + 0, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numFloatingBuffers + " buffers available in local pool.", + numFloatingBuffers, bufferPool.getNumberOfAvailableMemorySegments()); + assertEquals("There should be " + numExclusiveSegments + " buffers available in global pool.", + numExclusiveSegments, networkBufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources once exception + if (!inputChannel.isReleased()) { + inputChannel.releaseAllResources(); + } + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + + executor.shutdown(); + } } // 
--------------------------------------------------------------------------------------------- + private SingleInputGate createSingleInputGate() { + return new SingleInputGate( + "InputGate", + new JobID(), + new IntermediateDataSetID(), + ResultPartitionType.PIPELINED_CREDIT_BASED, + 0, + 1, + mock(TaskActions.class), + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + } + private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate) throws IOException, InterruptedException { @@ -403,4 +890,78 @@ private RemoteInputChannel createRemoteInputChannel( initialAndMaxRequestBackoff._2(), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); } + + /** + * Requests the exclusive buffers from input channel first and then recycles them by a callable task. + * + * @param inputChannel The input channel that exclusive buffers request from. + * @param numExclusiveSegments The number of exclusive buffers to request. + * @return The callable task to recycle exclusive buffers. + */ + private Callable recycleExclusiveBufferTask(RemoteInputChannel inputChannel, int numExclusiveSegments) { + final List exclusiveBuffers = new ArrayList<>(numExclusiveSegments); + // Exhaust all the exclusive buffers + for (int i = 0; i < numExclusiveSegments; i++) { + Buffer buffer = inputChannel.requestBuffer(); + assertNotNull(buffer); + exclusiveBuffers.add(buffer); + } + + return new Callable() { + @Override + public Void call() throws Exception { + for (Buffer buffer : exclusiveBuffers) { + buffer.recycle(); + } + + return null; + } + }; + } + + /** + * Requests the floating buffers from pool first and then recycles them by a callable task. + * + * @param bufferPool The buffer pool that floating buffers request from. + * @param numFloatingBuffers The number of floating buffers to request. + * @return The callable task to recycle floating buffers. 
+ */ + private Callable recycleFloatingBufferTask(BufferPool bufferPool, int numFloatingBuffers) throws Exception { + final List floatingBuffers = new ArrayList<>(numFloatingBuffers); + // Exhaust all the floating buffers + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } + + return new Callable() { + @Override + public Void call() throws Exception { + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + return null; + } + }; + } + + /** + * Submits all the callable tasks to the executor and waits for the results. + * + * @param executor The executor service for running tasks. + * @param tasks The callable tasks to be submitted and executed. + */ + private void submitTasksAndWaitForResults(ExecutorService executor, Callable[] tasks) throws Exception { + final List results = Lists.newArrayListWithCapacity(tasks.length); + + for(Callable task : tasks) { + results.add(executor.submit(task)); + } + + for (Future result : results) { + result.get(); + } + } } From 7c364f84df0969cc135c5a9835681c9d2d8e678a Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Thu, 28 Sep 2017 23:39:26 +0800 Subject: [PATCH 2/4] [FLINK-7416][network] Implement Netty receiver outgoing pipeline for credit-based --- .../netty/CreditBasedClientHandler.java | 108 ++++++- .../io/network/netty/NettyMessage.java | 64 ++++ .../network/netty/PartitionRequestClient.java | 7 + .../netty/PartitionRequestClientHandler.java | 7 + .../partition/consumer/InputChannel.java | 4 + .../consumer/RemoteInputChannel.java | 41 ++- .../netty/NettyMessageSerializationTest.java | 9 + .../PartitionRequestClientHandlerTest.java | 276 ++++++++++++++++-- .../consumer/RemoteInputChannelTest.java | 21 +- 9 files changed, 499 insertions(+), 38 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java index 1f1858843efc8..f5279bff1b5a0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java @@ -25,10 +25,14 @@ import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException; import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException; import org.apache.flink.runtime.io.network.netty.exception.TransportException; +import org.apache.flink.runtime.io.network.netty.NettyMessage.AddCredit; import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; @@ -37,6 +41,7 @@ import java.io.IOException; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicReference; @@ -52,14 +57,23 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter { /** Channels, which already requested partitions from the producers. */ private final ConcurrentMap inputChannels = new ConcurrentHashMap<>(); + /** Channels, which will notify the producers about unannounced credit. 
*/ + private final ArrayDeque inputChannelsWithCredit = new ArrayDeque<>(); + private final AtomicReference channelError = new AtomicReference<>(); + private final ChannelFutureListener writeListener = new WriteAndFlushNextMessageIfPossibleListener(); + /** * Set of cancelled partition requests. A request is cancelled iff an input channel is cleared * while data is still coming in for this channel. */ private final ConcurrentMap cancelled = new ConcurrentHashMap<>(); + /** + * The channel handler context is initialized in channel active event by netty thread, the context may also + * be accessed by task thread or canceler thread to cancel partition request during releasing resources. + */ private volatile ChannelHandlerContext ctx; // ------------------------------------------------------------------------ @@ -88,6 +102,22 @@ void cancelRequestFor(InputChannelID inputChannelId) { } } + /** + * The credit begins to announce after receiving the sender's backlog from buffer response. + * That means it should only happen after some interactions with the channel to make sure + * the context will not be null. + * + * @param inputChannel The input channel with unannounced credits. 
+ */ + void notifyCreditAvailable(final RemoteInputChannel inputChannel) { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + ctx.pipeline().fireUserEventTriggered(inputChannel); + } + }); + } + // ------------------------------------------------------------------------ // Network events // ------------------------------------------------------------------------ @@ -123,7 +153,6 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { */ @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause instanceof TransportException) { notifyAllChannelsOfErrorAndClose(cause); } else { @@ -152,6 +181,29 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } } + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof RemoteInputChannel) { + // Queue an input channel for available credits announcement. + // If the queue is empty, we try to trigger the actual write. Otherwise + // this will be handled by the writeAndFlushNextMessageIfPossible calls. 
+ boolean triggerWrite = inputChannelsWithCredit.isEmpty(); + + inputChannelsWithCredit.add((RemoteInputChannel) msg); + + if (triggerWrite) { + writeAndFlushNextMessageIfPossible(ctx.channel()); + } + } else { + ctx.fireUserEventTriggered(msg); + } + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + writeAndFlushNextMessageIfPossible(ctx.channel()); + } + private void notifyAllChannelsOfErrorAndClose(Throwable cause) { if (channelError.compareAndSet(null, cause)) { try { @@ -163,6 +215,7 @@ private void notifyAllChannelsOfErrorAndClose(Throwable cause) { LOG.warn("An Exception was thrown during error notification of a remote input channel.", t); } finally { inputChannels.clear(); + inputChannelsWithCredit.clear(); if (ctx != null) { ctx.close(); @@ -274,4 +327,57 @@ private void decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessage.B bufferOrEvent.releaseBuffer(); } } + + /** + * Fetches one un-released input channel from the queue and writes the + * unannounced credits immediately. After this is done, we will continue + * with the next input channel via listener's callback. + */ + private void writeAndFlushNextMessageIfPossible(Channel channel) { + if (channelError.get() != null || !channel.isWritable()) { + return; + } + + while (true) { + RemoteInputChannel inputChannel = inputChannelsWithCredit.poll(); + + // The input channel may be null because of the write callbacks + // that are executed after each write. + if (inputChannel == null) { + return; + } + + // There is no need to notify credit for a released channel. + if (!inputChannel.isReleased()) { + AddCredit msg = new AddCredit( + inputChannel.getPartitionId(), + inputChannel.getAndResetUnannouncedCredit(), + inputChannel.getInputChannelId()); + + // Write and flush and wait until this is done before + // trying to continue with the next input channel. 
+ channel.writeAndFlush(msg).addListener(writeListener); + + return; + } + } + } + + private class WriteAndFlushNextMessageIfPossibleListener implements ChannelFutureListener { + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + try { + if (future.isSuccess()) { + writeAndFlushNextMessageIfPossible(future.channel()); + } else if (future.cause() != null) { + notifyAllChannelsOfErrorAndClose(future.cause()); + } else { + notifyAllChannelsOfErrorAndClose(new IllegalStateException("Sending cancelled by user.")); + } + } catch (Throwable t) { + notifyAllChannelsOfErrorAndClose(t); + } + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java index db1b899b83209..cffad83f21f95 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java @@ -198,6 +198,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) case CloseRequest.ID: decodedMsg = CloseRequest.readFrom(msg); break; + case AddCredit.ID: + decodedMsg = AddCredit.readFrom(msg); + break; default: throw new ProtocolException("Received unknown message from producer: " + msg); } @@ -584,4 +587,65 @@ static CloseRequest readFrom(@SuppressWarnings("unused") ByteBuf buffer) throws return new CloseRequest(); } } + + /** + * Incremental credit announcement from the client to the server. 
+ */ + static class AddCredit extends NettyMessage { + + private static final byte ID = 6; + + final ResultPartitionID partitionId; + + final int credit; + + final InputChannelID receiverId; + + AddCredit(ResultPartitionID partitionId, int credit, InputChannelID receiverId) { + checkArgument(credit > 0, "The announced credit should be greater than 0"); + + this.partitionId = partitionId; + this.credit = credit; + this.receiverId = receiverId; + } + + @Override + ByteBuf write(ByteBufAllocator allocator) throws IOException { + ByteBuf result = null; + + try { + result = allocateBuffer(allocator, ID, 16 + 16 + 4 + 16); + + partitionId.getPartitionId().writeTo(result); + partitionId.getProducerId().writeTo(result); + result.writeInt(credit); + receiverId.writeTo(result); + + return result; + } + catch (Throwable t) { + if (result != null) { + result.release(); + } + + throw new IOException(t); + } + } + + static AddCredit readFrom(ByteBuf buffer) { + ResultPartitionID partitionId = + new ResultPartitionID( + IntermediateResultPartitionID.fromByteBuf(buffer), + ExecutionAttemptID.fromByteBuf(buffer)); + int credit = buffer.readInt(); + InputChannelID receiverId = InputChannelID.fromByteBuf(buffer); + + return new AddCredit(partitionId, credit, receiverId); + } + + @Override + public String toString() { + return String.format("AddCredit(%s : %d)", receiverId, credit); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClient.java index 8dbc6b7a02c1e..12a9531784d8f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClient.java @@ -167,6 +167,13 @@ public void operationComplete(ChannelFuture future) throws Exception { }); } + public void 
notifyCreditAvailable(RemoteInputChannel inputChannel) { + // We should skip the notification if the client is already closed. + if (!closeReferenceCounter.isDisposed()) { + partitionRequestHandler.notifyCreditAvailable(inputChannel); + } + } + public void close(RemoteInputChannel inputChannel) throws IOException { partitionRequestHandler.removeInputChannel(inputChannel); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java index ab4798e21720c..e50c0592c293b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java @@ -330,6 +330,13 @@ else if (bufferProvider.isDestroyed()) { } } + /** + * This class will eventually be replaced by CreditBasedClientHandler, + * so this method is a no-op here and is only implemented in CreditBasedClientHandler. + */ + void notifyCreditAvailable(RemoteInputChannel inputChannel) { + } + private class AsyncErrorNotificationTask implements Runnable { private final Throwable error; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java index f46abfdc154a6..68b05d45aa8a7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java @@ -100,6 +100,10 @@ int getChannelIndex() { return channelIndex; } + public ResultPartitionID getPartitionId() { + return partitionId; + } + /** * Notifies the owning {@link SingleInputGate} that this channel became non-empty. 
* diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 02c7b34863ed4..7605075c6f150 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -154,8 +154,9 @@ void assignExclusiveSegments(List segments) { /** * Requests a remote subpartition. */ + @VisibleForTesting @Override - void requestSubpartition(int subpartitionIndex) throws IOException, InterruptedException { + public void requestSubpartition(int subpartitionIndex) throws IOException, InterruptedException { if (partitionRequestClient == null) { // Create a client and request the partition partitionRequestClient = connectionManager @@ -279,10 +280,15 @@ public String toString() { // ------------------------------------------------------------------------ /** - * Enqueue this input channel in the pipeline for sending unannounced credits to producer. + * Enqueue this input channel in the pipeline for notifying the producer of unannounced credit. */ void notifyCreditAvailable() { - //TODO in next PR + checkState(partitionRequestClient != null, "Tried to send task event to producer before requesting a queue."); + + // We should skip the notification if this channel is already released. + if (!isReleased.get()) { + partitionRequestClient.notifyCreditAvailable(this); + } } /** @@ -320,11 +326,14 @@ public int getNumberOfAvailableBuffers() { } } - @VisibleForTesting public int getNumberOfRequiredBuffers() { return numRequiredBuffers; } + public int getSenderBacklog() { + return numRequiredBuffers - initialCredit; + } + /** * The Buffer pool notifies this channel of an available floating buffer. 
If the channel is released or * currently does not need extra buffers, the buffer should be recycled to the buffer pool. Otherwise, @@ -379,6 +388,29 @@ public void notifyBufferDestroyed() { // Network I/O notifications (called by network I/O thread) // ------------------------------------------------------------------------ + /** + * Gets the currently unannounced credit. + * + * @return Credit which was not announced to the sender yet. + */ + public int getUnannouncedCredit() { + return unannouncedCredit.get(); + } + + /** + * Gets the unannounced credit and resets it to 0 atomically. + * + * @return Credit which was not announced to the sender yet. + */ + public int getAndResetUnannouncedCredit() { + return unannouncedCredit.getAndSet(0); + } + + /** + * Gets the current number of received buffers which have not been processed yet. + * + * @return Buffers queued for processing. + */ public int getNumberOfQueuedBuffers() { synchronized (receivedBuffers) { return receivedBuffers.size(); @@ -426,7 +458,6 @@ public Buffer requestBuffer() { * * @param backlog The number of unsent buffers in the producer's sub partition. 
*/ - @VisibleForTesting void onSenderBacklog(int backlog) throws IOException { int numRequestedBuffers = 0; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java index 8c87cebca2471..98614bcbe6234 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java @@ -158,6 +158,15 @@ public void testEncodeDecode() { assertEquals(expected.getClass(), actual.getClass()); } + + { + NettyMessage.AddCredit expected = new NettyMessage.AddCredit(new ResultPartitionID(new IntermediateResultPartitionID(), new ExecutionAttemptID()), random.nextInt(Integer.MAX_VALUE) + 1, new InputChannelID()); + NettyMessage.AddCredit actual = encodeAndDecode(expected); + + assertEquals(expected.partitionId, actual.partitionId); + assertEquals(expected.credit, actual.credit); + assertEquals(expected.receiverId, actual.receiverId); + } } @SuppressWarnings("unchecked") diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java index d3ff6c26afc26..42a5f11bbb7f0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java @@ -18,34 +18,53 @@ package org.apache.flink.runtime.io.network.netty; -import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.io.network.ConnectionID; +import 
org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferListener; +import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferProvider; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.netty.NettyMessage.BufferResponse; import org.apache.flink.runtime.io.network.netty.NettyMessage.ErrorResponse; +import org.apache.flink.runtime.io.network.netty.NettyMessage.AddCredit; import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.io.network.util.TestBufferFactory; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator; import org.apache.flink.shaded.netty4.io.netty.channel.Channel; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; import org.junit.Test; import java.io.IOException; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static 
org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.hamcrest.Matchers.instanceOf; public class PartitionRequestClientHandlerTest { @@ -74,7 +93,7 @@ public void testReleaseInputChannelDuringDecode() throws Exception { when(inputChannel.getBufferProvider()).thenReturn(bufferProvider); final BufferResponse receivedBuffer = createBufferResponse( - TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); final PartitionRequestClientHandler client = new PartitionRequestClientHandler(); client.addInputChannel(inputChannel); @@ -122,21 +141,33 @@ public void testReceiveEmptyBuffer() throws Exception { */ @Test public void testReceiveBuffer() throws Exception { - final Buffer buffer = TestBufferFactory.createBuffer(); - final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); - when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); - when(inputChannel.requestBuffer()).thenReturn(buffer); - - final int backlog = 2; - final BufferResponse bufferResponse = createBufferResponse( - TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), backlog); - - final CreditBasedClientHandler client = new CreditBasedClientHandler(); - client.addInputChannel(inputChannel); - - client.channelRead(mock(ChannelHandlerContext.class), bufferResponse); - - verify(inputChannel, times(1)).onBuffer(buffer, 0, backlog); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = 
createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.getPartitionId().getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(8, 8); + inputGate.setBufferPool(bufferPool); + final int numExclusiveBuffers = 2; + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + final CreditBasedClientHandler handler = new CreditBasedClientHandler(); + handler.addInputChannel(inputChannel); + + final int backlog = 2; + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(32), 0, inputChannel.getInputChannelId(), backlog); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + + assertEquals(1, inputChannel.getNumberOfQueuedBuffers()); + assertEquals(2, inputChannel.getSenderBacklog()); + } finally { + // Release all the buffer resources + inputGate.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } /** @@ -145,17 +176,18 @@ public void testReceiveBuffer() throws Exception { */ @Test public void testThrowExceptionForNoAvailableBuffer() throws Exception { - final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); - when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); - when(inputChannel.requestBuffer()).thenReturn(null); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); - final BufferResponse bufferResponse = createBufferResponse( - TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); + final CreditBasedClientHandler handler = new CreditBasedClientHandler(); + handler.addInputChannel(inputChannel); - final CreditBasedClientHandler client = new CreditBasedClientHandler(); - client.addInputChannel(inputChannel); + assertEquals("There should be no buffers available in the channel.", + 0, 
inputChannel.getNumberOfAvailableBuffers()); - client.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse); verify(inputChannel, times(1)).onError(any(IllegalStateException.class)); } @@ -208,8 +240,200 @@ public void testCancelBeforeActive() throws Exception { client.cancelRequestFor(inputChannel.getInputChannelId()); } + /** + * Verifies that {@link RemoteInputChannel} is enqueued in the pipeline for notifying credits, + * and verifies the behaviour of credit notification by triggering channel's writability changed. + */ + @Test + public void testNotifyCreditAvailable() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel1 = createRemoteInputChannel(inputGate); + final RemoteInputChannel inputChannel2 = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel1.getPartitionId().getPartitionId(), inputChannel1); + inputGate.setInputChannel(inputChannel2.getPartitionId().getPartitionId(), inputChannel2); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6); + inputGate.setBufferPool(bufferPool); + final int numExclusiveBuffers = 2; + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + final CreditBasedClientHandler handler = new CreditBasedClientHandler(); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + + // The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we + // have to add input channels in CreditBasedClientHandler explicitly + inputChannel1.requestSubpartition(0); + inputChannel2.requestSubpartition(0); + handler.addInputChannel(inputChannel1); + 
handler.addInputChannel(inputChannel2); + + // The buffer response will take one available buffer from input channel, and it will trigger + // requesting (backlog + numExclusiveBuffers - numAvailableBuffers) floating buffers + final BufferResponse bufferResponse1 = createBufferResponse( + TestBufferFactory.createBuffer(32), 0, inputChannel1.getInputChannelId(), 1); + final BufferResponse bufferResponse2 = createBufferResponse( + TestBufferFactory.createBuffer(32), 0, inputChannel2.getInputChannelId(), 1); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse1); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse2); + + // The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we + // have to notify credit available in CreditBasedClientHandler explicitly + handler.notifyCreditAvailable(inputChannel1); + handler.notifyCreditAvailable(inputChannel2); + + assertEquals(2, inputChannel1.getUnannouncedCredit()); + assertEquals(2, inputChannel2.getUnannouncedCredit()); + + channel.runPendingTasks(); + + // The two input channels should notify credits via writable channel + assertTrue(channel.isWritable()); + Object readFromOutbound = channel.readOutbound(); + assertThat(readFromOutbound, instanceOf(AddCredit.class)); + assertEquals(2, ((AddCredit) readFromOutbound).credit); + readFromOutbound = channel.readOutbound(); + assertThat(readFromOutbound, instanceOf(AddCredit.class)); + assertEquals(2, ((AddCredit) readFromOutbound).credit); + assertNull(channel.readOutbound()); + + final int highWaterMark = channel.config().getWriteBufferHighWaterMark(); + // Set the writer index to the high water mark to ensure that all bytes are written + // to the wire although the buffer is "empty". 
+ ByteBuf channelBlockingBuffer = Unpooled.buffer(highWaterMark).writerIndex(highWaterMark); + channel.write(channelBlockingBuffer); + + // Trigger a credit notification via a buffer response while the channel is un-writable + final BufferResponse bufferResponse3 = createBufferResponse( + TestBufferFactory.createBuffer(32), 1, inputChannel1.getInputChannelId(), 1); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse3); + handler.notifyCreditAvailable(inputChannel1); + + assertEquals(1, inputChannel1.getUnannouncedCredit()); + assertEquals(0, inputChannel2.getUnannouncedCredit()); + + channel.runPendingTasks(); + + // The input channel will not notify credits via an un-writable channel + assertFalse(channel.isWritable()); + assertNull(channel.readOutbound()); + + // Flush the buffer to make the channel writable again + channel.flush(); + assertSame(channelBlockingBuffer, channel.readOutbound()); + + // The input channel should notify credits via channel's writability changed event + assertTrue(channel.isWritable()); + readFromOutbound = channel.readOutbound(); + assertThat(readFromOutbound, instanceOf(AddCredit.class)); + assertEquals(1, ((AddCredit) readFromOutbound).credit); + assertEquals(0, inputChannel1.getUnannouncedCredit()); + assertEquals(0, inputChannel2.getUnannouncedCredit()); + + // no more messages + assertNull(channel.readOutbound()); + } finally { + // Release all the buffer resources + inputGate.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } + } + + /** + * Verifies that a {@link RemoteInputChannel} is enqueued in the pipeline, but the {@link AddCredit} + * message is not actually sent when this input channel is released.
+ */ + @Test + public void testNotifyCreditAvailableAfterReleased() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.getPartitionId().getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6); + inputGate.setBufferPool(bufferPool); + final int numExclusiveBuffers = 2; + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + final CreditBasedClientHandler handler = new CreditBasedClientHandler(); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + + // The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we + // have to add input channels in CreditBasedClientHandler explicitly + inputChannel.requestSubpartition(0); + handler.addInputChannel(inputChannel); + + // Trigger request floating buffers via buffer response to notify credits available + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(32), 0, inputChannel.getInputChannelId(), 1); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + + assertEquals(2, inputChannel.getUnannouncedCredit()); + + // The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we + // have to notify credit available in CreditBasedClientHandler explicitly + handler.notifyCreditAvailable(inputChannel); + + // Release the input channel + inputGate.releaseAllResources(); + + channel.runPendingTasks(); + + // It will not notify credits for released input channel + assertNull(channel.readOutbound()); + } finally { + // Release all the buffer resources + inputGate.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } + } + // 
--------------------------------------------------------------------------------------------- + /** + * Creates and returns the single input gate for credit-based testing. + * + * @return The newly created single input gate. + */ + private SingleInputGate createSingleInputGate() { + return new SingleInputGate( + "InputGate", + new JobID(), + new IntermediateDataSetID(), + ResultPartitionType.PIPELINED_CREDIT_BASED, + 0, + 1, + mock(TaskActions.class), + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + } + + /** + * Creates and returns a remote input channel for the specific input gate. + * + * @param inputGate The input gate that owns the created input channel. + * @return The newly created remote input channel. + */ + private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate) throws Exception { + final ConnectionManager connectionManager = mock(ConnectionManager.class); + final PartitionRequestClient partitionRequestClient = mock(PartitionRequestClient.class); + when(connectionManager.createPartitionRequestClient(any(ConnectionID.class))) + .thenReturn(partitionRequestClient); + + return new RemoteInputChannel( + inputGate, + 0, + new ResultPartitionID(), + mock(ConnectionID.class), + connectionManager, + 0, + 0, + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + } + /** * Returns a deserialized buffer message as it would be received during runtime.
*/ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index 863f8865c6f40..eab1d89f63b90 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -328,6 +328,7 @@ public void testAvailableBuffersLessThanRequiredBuffers() throws Exception { final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + inputChannel.requestSubpartition(0); // Prepare the exclusive and floating buffers to verify recycle logic later final Buffer exclusiveBuffer = inputChannel.requestBuffer(); @@ -449,6 +450,7 @@ public void testAvailableBuffersEqualToRequiredBuffers() throws Exception { final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + inputChannel.requestSubpartition(0); // Prepare the exclusive and floating buffers to verify recycle logic later final Buffer exclusiveBuffer = inputChannel.requestBuffer(); @@ -526,6 +528,7 @@ public void testAvailableBuffersMoreThanRequiredBuffers() throws Exception { final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + inputChannel.requestSubpartition(0); // Prepare the exclusive and floating buffers to verify recycle logic later final Buffer exclusiveBuffer = 
inputChannel.requestBuffer(); @@ -621,6 +624,9 @@ public void testFairDistributionFloatingBuffers() throws Exception { final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + channel1.requestSubpartition(0); + channel2.requestSubpartition(0); + channel3.requestSubpartition(0); // Exhaust all the floating buffers final List floatingBuffers = new ArrayList<>(numFloatingBuffers); @@ -690,6 +696,7 @@ public void testConcurrentOnSenderBacklogAndRelease() throws Exception { final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + inputChannel.requestSubpartition(0); final Callable requestBufferTask = new Callable() { @Override @@ -758,6 +765,7 @@ public void testConcurrentOnSenderBacklogAndRecycle() throws Exception { final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + inputChannel.requestSubpartition(0); final Callable requestBufferTask = new Callable() { @Override @@ -772,9 +780,9 @@ public Void call() throws Exception { // Submit tasks and wait to finish submitTasksAndWaitForResults(executor, new Callable[]{ - recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), - recycleFloatingBufferTask(bufferPool, numFloatingBuffers), - requestBufferTask}); + recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), + recycleFloatingBufferTask(bufferPool, numFloatingBuffers), + requestBufferTask}); assertEquals("There should be " + inputChannel.getNumberOfRequiredBuffers() +" buffers available in channel.", inputChannel.getNumberOfRequiredBuffers(), 
inputChannel.getNumberOfAvailableBuffers()); @@ -813,6 +821,7 @@ public void testConcurrentRecycleAndRelease() throws Exception { final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + inputChannel.requestSubpartition(0); final Callable releaseTask = new Callable() { @Override @@ -825,9 +834,9 @@ public Void call() throws Exception { // Submit tasks and wait to finish submitTasksAndWaitForResults(executor, new Callable[]{ - recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), - recycleFloatingBufferTask(bufferPool, numFloatingBuffers), - releaseTask}); + recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), + recycleFloatingBufferTask(bufferPool, numFloatingBuffers), + releaseTask}); assertEquals("There should be no buffers available in the channel.", 0, inputChannel.getNumberOfAvailableBuffers()); From cfe85ced9a4c5bbc342e527d8b59e17ecf1cdad6 Mon Sep 17 00:00:00 2001 From: Piotr Nowojski Date: Mon, 18 Dec 2017 15:26:20 +0100 Subject: [PATCH 3/4] [hotfix][network] Drop redundant this reference usages --- .../SpanningRecordSerializer.java | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java index 87b9e4cb23ed7..7394f83712f8d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java @@ -59,14 +59,14 @@ public class SpanningRecordSerializer implements R private int limit; public SpanningRecordSerializer() { - this.serializationBuffer = new 
DataOutputSerializer(128); + serializationBuffer = new DataOutputSerializer(128); - this.lengthBuffer = ByteBuffer.allocate(4); - this.lengthBuffer.order(ByteOrder.BIG_ENDIAN); + lengthBuffer = ByteBuffer.allocate(4); + lengthBuffer.order(ByteOrder.BIG_ENDIAN); // ensure initial state with hasRemaining false (for correct setNextBuffer logic) - this.dataBuffer = this.serializationBuffer.wrapAsByteBuffer(); - this.lengthBuffer.position(4); + dataBuffer = serializationBuffer.wrapAsByteBuffer(); + lengthBuffer.position(4); } /** @@ -81,50 +81,50 @@ public SpanningRecordSerializer() { @Override public SerializationResult addRecord(T record) throws IOException { if (CHECKED) { - if (this.dataBuffer.hasRemaining()) { + if (dataBuffer.hasRemaining()) { throw new IllegalStateException("Pending serialization of previous record."); } } - this.serializationBuffer.clear(); - this.lengthBuffer.clear(); + serializationBuffer.clear(); + lengthBuffer.clear(); // write data and length - record.write(this.serializationBuffer); + record.write(serializationBuffer); - int len = this.serializationBuffer.length(); - this.lengthBuffer.putInt(0, len); + int len = serializationBuffer.length(); + lengthBuffer.putInt(0, len); - this.dataBuffer = this.serializationBuffer.wrapAsByteBuffer(); + dataBuffer = serializationBuffer.wrapAsByteBuffer(); // Copy from intermediate buffers to current target memory segment - copyToTargetBufferFrom(this.lengthBuffer); - copyToTargetBufferFrom(this.dataBuffer); + copyToTargetBufferFrom(lengthBuffer); + copyToTargetBufferFrom(dataBuffer); return getSerializationResult(); } @Override public SerializationResult setNextBuffer(Buffer buffer) throws IOException { - this.targetBuffer = buffer; - this.position = 0; - this.limit = buffer.getSize(); + targetBuffer = buffer; + position = 0; + limit = buffer.getSize(); - if (this.lengthBuffer.hasRemaining()) { - copyToTargetBufferFrom(this.lengthBuffer); + if (lengthBuffer.hasRemaining()) { + 
copyToTargetBufferFrom(lengthBuffer); } - if (this.dataBuffer.hasRemaining()) { - copyToTargetBufferFrom(this.dataBuffer); + if (dataBuffer.hasRemaining()) { + copyToTargetBufferFrom(dataBuffer); } SerializationResult result = getSerializationResult(); // make sure we don't hold onto the large buffers for too long if (result.isFullRecord()) { - this.serializationBuffer.clear(); - this.serializationBuffer.pruneBuffer(); - this.dataBuffer = this.serializationBuffer.wrapAsByteBuffer(); + serializationBuffer.clear(); + serializationBuffer.pruneBuffer(); + dataBuffer = serializationBuffer.wrapAsByteBuffer(); } return result; @@ -137,22 +137,22 @@ public SerializationResult setNextBuffer(Buffer buffer) throws IOException { * @param source the {@link ByteBuffer} to copy data from */ private void copyToTargetBufferFrom(ByteBuffer source) { - if (this.targetBuffer == null) { + if (targetBuffer == null) { return; } int needed = source.remaining(); - int available = this.limit - this.position; + int available = limit - position; int toCopy = Math.min(needed, available); - this.targetBuffer.getMemorySegment().put(this.position, source, toCopy); + targetBuffer.getMemorySegment().put(position, source, toCopy); - this.position += toCopy; + position += toCopy; } private SerializationResult getSerializationResult() { - if (!this.dataBuffer.hasRemaining() && !this.lengthBuffer.hasRemaining()) { - return (this.position < this.limit) + if (!dataBuffer.hasRemaining() && !lengthBuffer.hasRemaining()) { + return (position < limit) ? 
SerializationResult.FULL_RECORD : SerializationResult.FULL_RECORD_MEMORY_SEGMENT_FULL; } @@ -166,8 +166,8 @@ public Buffer getCurrentBuffer() { return null; } - this.targetBuffer.setSize(this.position); - return this.targetBuffer; + targetBuffer.setSize(position); + return targetBuffer; } @Override @@ -179,19 +179,19 @@ public void clearCurrentBuffer() { @Override public void clear() { - this.targetBuffer = null; - this.position = 0; - this.limit = 0; + targetBuffer = null; + position = 0; + limit = 0; // ensure clear state with hasRemaining false (for correct setNextBuffer logic) - this.dataBuffer.position(this.dataBuffer.limit()); - this.lengthBuffer.position(4); + dataBuffer.position(dataBuffer.limit()); + lengthBuffer.position(4); } @Override public boolean hasData() { // either data in current target buffer or intermediate buffers - return this.position > 0 || (this.lengthBuffer.hasRemaining() || this.dataBuffer.hasRemaining()); + return position > 0 || (lengthBuffer.hasRemaining() || dataBuffer.hasRemaining()); } @Override From dc5427696476d898da6ece19c7e05f6063ed98ae Mon Sep 17 00:00:00 2001 From: Piotr Nowojski Date: Mon, 4 Dec 2017 15:13:57 +0100 Subject: [PATCH 4/4] [FLINK-8207][network-tests] Unify TestInfiniteBufferProvider and TestPooledBufferProvider --- .../network/api/writer/RecordWriterTest.java | 5 +- .../SpilledSubpartitionViewTest.java | 22 ++--- .../util/TestInfiniteBufferProvider.java | 81 ------------------- .../util/TestPooledBufferProvider.java | 10 ++- 4 files changed, 17 insertions(+), 101 deletions(-) delete mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestInfiniteBufferProvider.java diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java index 95090130cc2a4..63540c39b9256 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java @@ -36,7 +36,7 @@ import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.io.network.util.TestBufferFactory; -import org.apache.flink.runtime.io.network.util.TestInfiniteBufferProvider; +import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider; import org.apache.flink.runtime.io.network.util.TestTaskEvent; import org.apache.flink.runtime.testutils.DiscardingRecycler; import org.apache.flink.types.IntValue; @@ -421,8 +421,7 @@ public void testBroadcastEventBufferReferenceCounting() throws Exception { new ArrayDeque[]{new ArrayDeque(), new ArrayDeque()}; ResultPartitionWriter partition = - createCollectingPartitionWriter(queues, - new TestInfiniteBufferProvider()); + createCollectingPartitionWriter(queues, new TestPooledBufferProvider(Integer.MAX_VALUE)); RecordWriter writer = new RecordWriter<>(partition); writer.broadcastEvent(EndOfPartitionEvent.INSTANCE); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SpilledSubpartitionViewTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SpilledSubpartitionViewTest.java index b748e1c30f8a6..69d19fccc20f2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SpilledSubpartitionViewTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SpilledSubpartitionViewTest.java @@ -24,8 +24,8 @@ import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; +import 
org.apache.flink.runtime.io.network.util.TestBufferFactory; import org.apache.flink.runtime.io.network.util.TestConsumerCallback; -import org.apache.flink.runtime.io.network.util.TestInfiniteBufferProvider; import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider; import org.apache.flink.runtime.io.network.util.TestSubpartitionConsumer; @@ -52,9 +52,6 @@ public class SpilledSubpartitionViewTest { private static final IOManager IO_MANAGER = new IOManagerAsync(); - private static final TestInfiniteBufferProvider writerBufferPool = - new TestInfiniteBufferProvider(); - @AfterClass public static void shutdown() { IO_MANAGER.shutdown(); @@ -66,7 +63,7 @@ public void testWriteConsume() throws Exception { final int numberOfBuffersToWrite = 512; // Setup - final BufferFileWriter writer = createWriterAndWriteBuffers(IO_MANAGER, writerBufferPool, numberOfBuffersToWrite); + final BufferFileWriter writer = createWriterAndWriteBuffers(numberOfBuffersToWrite); writer.close(); @@ -94,7 +91,7 @@ public void testConsumeWithFewBuffers() throws Exception { final int numberOfBuffersToWrite = 512; // Setup - final BufferFileWriter writer = createWriterAndWriteBuffers(IO_MANAGER, writerBufferPool, numberOfBuffersToWrite); + final BufferFileWriter writer = createWriterAndWriteBuffers(numberOfBuffersToWrite); writer.close(); @@ -134,8 +131,8 @@ public void testReadMultipleFilesWithSingleBufferPool() throws Exception { // Setup writers = new BufferFileWriter[]{ - createWriterAndWriteBuffers(IO_MANAGER, writerBufferPool, 512), - createWriterAndWriteBuffers(IO_MANAGER, writerBufferPool, 512) + createWriterAndWriteBuffers(512), + createWriterAndWriteBuffers(512) }; readers = new ResultSubpartitionView[writers.length]; @@ -211,15 +208,12 @@ public void testReadMultipleFilesWithSingleBufferPool() throws Exception { * *

Call {@link BufferFileWriter#close()} to ensure that all buffers have been written. */ - static BufferFileWriter createWriterAndWriteBuffers( - IOManager ioManager, - BufferProvider bufferProvider, - int numberOfBuffers) throws IOException { + private static BufferFileWriter createWriterAndWriteBuffers(int numberOfBuffers) throws IOException { - final BufferFileWriter writer = ioManager.createBufferFileWriter(ioManager.createChannel()); + final BufferFileWriter writer = IO_MANAGER.createBufferFileWriter(IO_MANAGER.createChannel()); for (int i = 0; i < numberOfBuffers; i++) { - writer.writeBlock(bufferProvider.requestBuffer()); + writer.writeBlock(TestBufferFactory.createBuffer()); } writer.writeBlock(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestInfiniteBufferProvider.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestInfiniteBufferProvider.java deleted file mode 100644 index ad40a54058b91..0000000000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestInfiniteBufferProvider.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.runtime.io.network.util; - -import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.buffer.BufferListener; -import org.apache.flink.runtime.io.network.buffer.BufferProvider; -import org.apache.flink.runtime.io.network.buffer.BufferRecycler; - -import java.io.IOException; -import java.util.concurrent.ConcurrentLinkedQueue; - -public class TestInfiniteBufferProvider implements BufferProvider { - - private final ConcurrentLinkedQueue buffers = new ConcurrentLinkedQueue(); - - private final TestBufferFactory bufferFactory = new TestBufferFactory( - 32 * 1024, new InfiniteBufferProviderRecycler(buffers)); - - @Override - public Buffer requestBuffer() throws IOException { - Buffer buffer = buffers.poll(); - - if (buffer != null) { - return buffer; - } - - return bufferFactory.create(); - } - - @Override - public Buffer requestBufferBlocking() throws IOException, InterruptedException { - return requestBuffer(); - } - - @Override - public boolean addBufferListener(BufferListener listener) { - return false; - } - - @Override - public boolean isDestroyed() { - return false; - } - - @Override - public int getMemorySegmentSize() { - return bufferFactory.getBufferSize(); - } - - private static class InfiniteBufferProviderRecycler implements BufferRecycler { - - private final ConcurrentLinkedQueue buffers; - - public InfiniteBufferProviderRecycler(ConcurrentLinkedQueue buffers) { - this.buffers = buffers; - } - - @Override - public void recycle(MemorySegment segment) { - buffers.add(new Buffer(segment, this)); - } - } -} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestPooledBufferProvider.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestPooledBufferProvider.java index c354eeb143996..221a535aad2d9 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestPooledBufferProvider.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestPooledBufferProvider.java @@ -28,8 +28,9 @@ import java.io.IOException; import java.util.Queue; -import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.LinkedBlockingDeque; import static org.apache.flink.util.Preconditions.checkArgument; @@ -37,7 +38,7 @@ public class TestPooledBufferProvider implements BufferProvider { private final Object bufferCreationLock = new Object(); - private final ArrayBlockingQueue buffers; + private final BlockingQueue buffers = new LinkedBlockingDeque<>(); private final TestBufferFactory bufferFactory; @@ -49,7 +50,6 @@ public TestPooledBufferProvider(int poolSize) { checkArgument(poolSize > 0); this.poolSize = poolSize; - this.buffers = new ArrayBlockingQueue(poolSize); this.bufferRecycler = new PooledBufferProviderRecycler(buffers); this.bufferFactory = new TestBufferFactory(32 * 1024, bufferRecycler); } @@ -109,6 +109,10 @@ public int getNumberOfAvailableBuffers() { return buffers.size(); } + public int getNumberOfCreatedBuffers() { + return bufferFactory.getNumberOfCreatedBuffers(); + } + private static class PooledBufferProviderRecycler implements BufferRecycler { private final Object listenerRegistrationLock = new Object();