Skip to content

Commit

Permalink
[FLINK-7406][network] Implement Netty receiver incoming pipeline for …
Browse files Browse the repository at this point in the history
…credit-based
  • Loading branch information
zhijiangW authored and StefanRRichter committed Jan 8, 2018
1 parent 542419b commit 268867c
Show file tree
Hide file tree
Showing 10 changed files with 1,175 additions and 214 deletions.
@@ -0,0 +1,277 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.io.network.netty;

import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException;
import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException;
import org.apache.flink.runtime.io.network.netty.exception.TransportException;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;

import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.SocketAddress;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Channel handler on the consumer side which reads {@link NettyMessage.BufferResponse} and
 * {@link NettyMessage.ErrorResponse} messages sent by the producer and dispatches them to the
 * {@link RemoteInputChannel} registered under the message's receiver id.
 *
 * <p>Data arriving for a receiver id that is not (or no longer) registered is released and
 * answered with a {@link NettyMessage.CancelPartitionRequest}. This version of the handler does
 * not itself announce credits to the producer.
 *
 * <p>The first error observed on the connection is remembered; it is forwarded to all registered
 * input channels, the connection is closed, and later registrations fail with that error.
 */
class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {

	private static final Logger LOG = LoggerFactory.getLogger(CreditBasedClientHandler.class);

	/** Message Netty reports when the peer resets the connection. */
	private static final String CONNECTION_RESET_MESSAGE = "Connection reset by peer";

	/** Channels, which already requested partitions from the producers. */
	private final ConcurrentMap<InputChannelID, RemoteInputChannel> inputChannels = new ConcurrentHashMap<>();

	/** The first error seen on this connection, or {@code null} if none occurred so far. */
	private final AtomicReference<Throwable> channelError = new AtomicReference<>();

	/**
	 * Set of cancelled partition requests. A request is cancelled iff an input channel is cleared
	 * while data is still coming in for this channel.
	 */
	private final ConcurrentMap<InputChannelID, InputChannelID> cancelled = new ConcurrentHashMap<>();

	/** Context of the active channel; stays {@code null} until {@link #channelActive} ran. */
	private volatile ChannelHandlerContext ctx;

	// ------------------------------------------------------------------------
	// Input channel/receiver registration
	// ------------------------------------------------------------------------

	/**
	 * Registers the given input channel under its id unless a channel with that id is already
	 * registered.
	 *
	 * @param listener the input channel to register
	 * @throws IOException if an error was previously reported on this connection
	 */
	void addInputChannel(RemoteInputChannel listener) throws IOException {
		checkError();

		// Atomic "register if absent"; a separate containsKey()/put() pair could let two
		// concurrent registrations for the same id overwrite each other.
		inputChannels.putIfAbsent(listener.getInputChannelId(), listener);
	}

	/**
	 * Removes the registration of the given input channel, if any.
	 *
	 * @param listener the input channel to unregister
	 */
	void removeInputChannel(RemoteInputChannel listener) {
		inputChannels.remove(listener.getInputChannelId());
	}

	/**
	 * Sends a {@link NettyMessage.CancelPartitionRequest} for the given input channel id, at most
	 * once per id. Does nothing if the id is {@code null} or the channel is not active yet.
	 *
	 * @param inputChannelId id of the input channel whose partition request is cancelled
	 */
	void cancelRequestFor(InputChannelID inputChannelId) {
		// Read the volatile field once so the null check and the write use the same reference.
		final ChannelHandlerContext context = ctx;

		if (inputChannelId == null || context == null) {
			return;
		}

		if (cancelled.putIfAbsent(inputChannelId, inputChannelId) == null) {
			context.writeAndFlush(new NettyMessage.CancelPartitionRequest(inputChannelId));
		}
	}

	// ------------------------------------------------------------------------
	// Network events
	// ------------------------------------------------------------------------

	/**
	 * Remembers the context of the first channel that becomes active and forwards the event.
	 */
	@Override
	public void channelActive(final ChannelHandlerContext ctx) throws Exception {
		if (this.ctx == null) {
			this.ctx = ctx;
		}

		super.channelActive(ctx);
	}

	/**
	 * Reports an error to all registered input channels if the connection is closed while
	 * channels are still registered, then forwards the event.
	 */
	@Override
	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
		// Unexpected close. In normal operation, the client closes the connection after all input
		// channels have been removed. This indicates a problem with the remote task manager.
		if (!inputChannels.isEmpty()) {
			final SocketAddress remoteAddr = ctx.channel().remoteAddress();

			notifyAllChannelsOfErrorAndClose(new RemoteTransportException(
				"Connection unexpectedly closed by remote task manager '" + remoteAddr + "'. "
					+ "This might indicate that the remote task manager was lost.", remoteAddr));
		}

		super.channelInactive(ctx);
	}

	/**
	 * Called on exceptions in the client handler pipeline.
	 *
	 * <p>Remote exceptions are received as regular payload.
	 *
	 * <p>{@link TransportException}s are forwarded as they are; every other cause is wrapped
	 * into a {@link RemoteTransportException} (connection reset by peer) or a
	 * {@link LocalTransportException} (anything else) before the input channels are notified.
	 */
	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {

		if (cause instanceof TransportException) {
			notifyAllChannelsOfErrorAndClose(cause);
		} else {
			final SocketAddress remoteAddr = ctx.channel().remoteAddress();

			final TransportException tex;

			// Improve on the connection reset by peer error message.
			// Constant-first equals(): an IOException may carry no message at all, and a
			// NullPointerException here would hide the actual failure.
			if (cause instanceof IOException && CONNECTION_RESET_MESSAGE.equals(cause.getMessage())) {
				tex = new RemoteTransportException("Lost connection to task manager '" + remoteAddr + "'. " +
					"This indicates that the remote task manager was lost.", remoteAddr, cause);
			} else {
				tex = new LocalTransportException(cause.getMessage(), ctx.channel().localAddress(), cause);
			}

			notifyAllChannelsOfErrorAndClose(tex);
		}
	}

	/**
	 * Decodes an incoming message; any failure while doing so is reported to all registered
	 * input channels and closes the connection.
	 */
	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		try {
			decodeMsg(msg);
		} catch (Throwable t) {
			notifyAllChannelsOfErrorAndClose(t);
		}
	}

	/**
	 * Records the given cause as the connection error (only the first call wins), passes it to
	 * every registered input channel, clears the registrations and closes the connection.
	 *
	 * @param cause the error to report
	 */
	private void notifyAllChannelsOfErrorAndClose(Throwable cause) {
		if (channelError.compareAndSet(null, cause)) {
			try {
				for (RemoteInputChannel inputChannel : inputChannels.values()) {
					// Guard each notification separately so that one failing channel does not
					// prevent the remaining channels from learning about the error.
					try {
						inputChannel.onError(cause);
					} catch (Throwable t) {
						// We can only swallow the Exception at this point. :(
						LOG.warn("An Exception was thrown during error notification of a remote input channel.", t);
					}
				}
			} finally {
				inputChannels.clear();

				final ChannelHandlerContext context = ctx;
				if (context != null) {
					context.close();
				}
			}
		}
	}

	// ------------------------------------------------------------------------

	/**
	 * Checks for an error and rethrows it if one was reported.
	 *
	 * @throws IOException the recorded error itself if it is an {@code IOException}, otherwise
	 *                     an {@code IOException} wrapping it
	 */
	private void checkError() throws IOException {
		final Throwable t = channelError.get();

		if (t != null) {
			if (t instanceof IOException) {
				throw (IOException) t;
			} else {
				throw new IOException("There has been an error in the channel.", t);
			}
		}
	}

	/**
	 * Dispatches a message received from the producer.
	 *
	 * @param msg the received message; must be a buffer response or an error response
	 * @throws IllegalStateException if the message is of any other type
	 */
	private void decodeMsg(Object msg) throws Throwable {
		final Class<?> msgClazz = msg.getClass();

		// ---- Buffer --------------------------------------------------------
		if (msgClazz == NettyMessage.BufferResponse.class) {
			NettyMessage.BufferResponse bufferOrEvent = (NettyMessage.BufferResponse) msg;

			RemoteInputChannel inputChannel = inputChannels.get(bufferOrEvent.receiverId);
			if (inputChannel == null) {
				// Nobody is registered for this data: drop it and ask the producer to stop.
				bufferOrEvent.releaseBuffer();

				cancelRequestFor(bufferOrEvent.receiverId);

				return;
			}

			decodeBufferOrEvent(inputChannel, bufferOrEvent);

		} else if (msgClazz == NettyMessage.ErrorResponse.class) {
			// ---- Error ---------------------------------------------------------
			NettyMessage.ErrorResponse error = (NettyMessage.ErrorResponse) msg;

			SocketAddress remoteAddr = ctx.channel().remoteAddress();

			if (error.isFatalError()) {
				notifyAllChannelsOfErrorAndClose(new RemoteTransportException(
					"Fatal error at remote task manager '" + remoteAddr + "'.",
					remoteAddr,
					error.cause));
			} else {
				RemoteInputChannel inputChannel = inputChannels.get(error.receiverId);

				if (inputChannel != null) {
					if (error.cause.getClass() == PartitionNotFoundException.class) {
						inputChannel.onFailedPartitionRequest();
					} else {
						inputChannel.onError(new RemoteTransportException(
							"Error at remote task manager '" + remoteAddr + "'.",
							remoteAddr,
							error.cause));
					}
				}
			}
		} else {
			throw new IllegalStateException("Received unknown message from producer: " + msg.getClass());
		}
	}

	/**
	 * Copies the payload of a buffer response out of the Netty buffer and hands it to the input
	 * channel. The Netty buffer is released in every case before this method returns.
	 *
	 * @param inputChannel  the channel registered for the response's receiver id
	 * @param bufferOrEvent the received response
	 * @throws IllegalStateException if the channel cannot provide a buffer although it has not
	 *                               been released
	 */
	private void decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessage.BufferResponse bufferOrEvent) throws Throwable {
		try {
			if (bufferOrEvent.isBuffer()) {
				// ---- Buffer ------------------------------------------------

				// Early return for empty buffers. Otherwise Netty's readBytes() throws an
				// IndexOutOfBoundsException.
				if (bufferOrEvent.getSize() == 0) {
					inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
					return;
				}

				Buffer buffer = inputChannel.requestBuffer();
				if (buffer != null) {
					buffer.setSize(bufferOrEvent.getSize());
					bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer());

					inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
				} else if (inputChannel.isReleased()) {
					cancelRequestFor(bufferOrEvent.receiverId);
				} else {
					throw new IllegalStateException("No buffer available in credit-based input channel.");
				}
			} else {
				// ---- Event -------------------------------------------------
				// TODO We can just keep the serialized data in the Netty buffer and release it later at the reader
				byte[] byteArray = new byte[bufferOrEvent.getSize()];
				bufferOrEvent.getNettyBuffer().readBytes(byteArray);

				MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray);
				Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false);

				inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
			}
		} finally {
			bufferOrEvent.releaseBuffer();
		}
	}
}
Expand Up @@ -221,6 +221,8 @@ static class BufferResponse extends NettyMessage {

final int sequenceNumber;

final int backlog;

// ---- Deserialization -----------------------------------------------

final boolean isBuffer;
Expand All @@ -232,7 +234,8 @@ static class BufferResponse extends NettyMessage {

private BufferResponse(
ByteBuf retainedSlice, boolean isBuffer, int sequenceNumber,
InputChannelID receiverId) {
InputChannelID receiverId,
int backlog) {
// When deserializing we first have to request a buffer from the respective buffer
// provider (at the handler) and copy the buffer from Netty's space to ours. Only
// retainedSlice is set in this case.
Expand All @@ -242,15 +245,17 @@ private BufferResponse(
this.isBuffer = isBuffer;
this.sequenceNumber = sequenceNumber;
this.receiverId = checkNotNull(receiverId);
this.backlog = backlog;
}

BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId) {
BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId, int backlog) {
this.buffer = checkNotNull(buffer);
this.retainedSlice = null;
this.isBuffer = buffer.isBuffer();
this.size = buffer.getSize();
this.sequenceNumber = sequenceNumber;
this.receiverId = checkNotNull(receiverId);
this.backlog = backlog;
}

boolean isBuffer() {
Expand Down Expand Up @@ -280,14 +285,15 @@ void releaseBuffer() {
ByteBuf write(ByteBufAllocator allocator) throws IOException {
checkNotNull(buffer, "No buffer instance to serialize.");

int length = 16 + 4 + 1 + 4 + buffer.getSize();
int length = 16 + 4 + 4 + 1 + 4 + buffer.getSize();

ByteBuf result = null;
try {
result = allocateBuffer(allocator, ID, length);

receiverId.writeTo(result);
result.writeInt(sequenceNumber);
result.writeInt(backlog);
result.writeBoolean(buffer.isBuffer());
result.writeInt(buffer.getSize());
result.writeBytes(buffer.getNioBuffer());
Expand All @@ -309,12 +315,13 @@ ByteBuf write(ByteBufAllocator allocator) throws IOException {
static BufferResponse readFrom(ByteBuf buffer) {
InputChannelID receiverId = InputChannelID.fromByteBuf(buffer);
int sequenceNumber = buffer.readInt();
int backlog = buffer.readInt();
boolean isBuffer = buffer.readBoolean();
int size = buffer.readInt();

ByteBuf retainedSlice = buffer.readSlice(size).retain();

return new BufferResponse(retainedSlice, isBuffer, sequenceNumber, receiverId);
return new BufferResponse(retainedSlice, isBuffer, sequenceNumber, receiverId, backlog);
}
}

Expand Down
Expand Up @@ -276,7 +276,7 @@ private boolean decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessag
// Early return for empty buffers. Otherwise Netty's readBytes() throws an
// IndexOutOfBoundsException.
if (bufferOrEvent.getSize() == 0) {
inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber);
inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, -1);
return true;
}

Expand All @@ -295,7 +295,7 @@ private boolean decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessag
buffer.setSize(bufferOrEvent.getSize());
bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer());

inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber);
inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, -1);

return true;
}
Expand All @@ -318,7 +318,7 @@ else if (bufferProvider.isDestroyed()) {
MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray);
Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false);

inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber);
inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, -1);

return true;
}
Expand Down Expand Up @@ -450,7 +450,7 @@ public void run() {
RemoteInputChannel inputChannel = inputChannels.get(stagedBufferResponse.receiverId);

if (inputChannel != null) {
inputChannel.onBuffer(buffer, stagedBufferResponse.sequenceNumber);
inputChannel.onBuffer(buffer, stagedBufferResponse.sequenceNumber, -1);

success = true;
}
Expand Down
Expand Up @@ -193,7 +193,8 @@ private void writeAndFlushNextMessageIfPossible(final Channel channel) throws IO
BufferResponse msg = new BufferResponse(
next.buffer(),
reader.getSequenceNumber(),
reader.getReceiverId());
reader.getReceiverId(),
0);

if (isEndOfPartitionEvent(next.buffer())) {
reader.notifySubpartitionConsumed();
Expand Down

0 comments on commit 268867c

Please sign in to comment.