Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-8768][network] Let NettyMessageDecoder inherit from LengthFieldBasedFrameDecoder #5570

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOutboundHandlerAdapter;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelPromise;
import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import org.apache.flink.shaded.netty4.io.netty.handler.codec.MessageToMessageDecoder;

import javax.annotation.Nullable;

Expand All @@ -47,7 +46,6 @@
import java.io.ObjectOutputStream;
import java.net.ProtocolException;
import java.nio.ByteBuffer;
import java.util.List;

import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
Expand Down Expand Up @@ -188,58 +186,81 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
ctx.write(msg, promise);
}
}

// Create the frame length decoder here as it depends on the encoder
//
// +------------------+------------------+--------++----------------+
// | FRAME LENGTH (4) | MAGIC NUMBER (4) | ID (1) || CUSTOM MESSAGE |
// +------------------+------------------+--------++----------------+
static LengthFieldBasedFrameDecoder createFrameLengthDecoder() {
return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, -4, 4);
}
}

@ChannelHandler.Sharable
static class NettyMessageDecoder extends MessageToMessageDecoder<ByteBuf> {
/**
* Message decoder based on netty's {@link LengthFieldBasedFrameDecoder} but avoiding the
* additional memory copy inside {@link #extractFrame(ChannelHandlerContext, ByteBuf, int, int)}
* since we completely decode the {@link ByteBuf} inside {@link #decode(ChannelHandlerContext,
* ByteBuf)} and will not re-use it afterwards.
*
* <p>The frame-length encoder will be based on this transmission scheme created by {@link NettyMessage#allocateBuffer(ByteBufAllocator, byte, int)}:
* <pre>
* +------------------+------------------+--------++----------------+
* | FRAME LENGTH (4) | MAGIC NUMBER (4) | ID (1) || CUSTOM MESSAGE |
* +------------------+------------------+--------++----------------+
* </pre>
*/
static class NettyMessageDecoder extends LengthFieldBasedFrameDecoder {

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) throws Exception {
int magicNumber = msg.readInt();
/**
* Creates a new message decoded with the required frame properties.
*/
NettyMessageDecoder() {
super(Integer.MAX_VALUE, 0, 4, -4, 4);
}

if (magicNumber != MAGIC_NUMBER) {
throw new IllegalStateException("Network stream corrupted: received incorrect magic number.");
@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
ByteBuf msg = (ByteBuf) super.decode(ctx, in);
if (msg == null) {
return null;
}

byte msgId = msg.readByte();

final NettyMessage decodedMsg;
switch (msgId) {
case BufferResponse.ID:
decodedMsg = BufferResponse.readFrom(msg);
break;
case PartitionRequest.ID:
decodedMsg = PartitionRequest.readFrom(msg);
break;
case TaskEventRequest.ID:
decodedMsg = TaskEventRequest.readFrom(msg, getClass().getClassLoader());
break;
case ErrorResponse.ID:
decodedMsg = ErrorResponse.readFrom(msg);
break;
case CancelPartitionRequest.ID:
decodedMsg = CancelPartitionRequest.readFrom(msg);
break;
case CloseRequest.ID:
decodedMsg = CloseRequest.readFrom(msg);
break;
case AddCredit.ID:
decodedMsg = AddCredit.readFrom(msg);
break;
default:
throw new ProtocolException("Received unknown message from producer: " + msg);
}
try {
int magicNumber = msg.readInt();

if (magicNumber != MAGIC_NUMBER) {
throw new IllegalStateException(
"Network stream corrupted: received incorrect magic number.");
}

byte msgId = msg.readByte();

final NettyMessage decodedMsg;
switch (msgId) {
case BufferResponse.ID:
decodedMsg = BufferResponse.readFrom(msg);
break;
case PartitionRequest.ID:
decodedMsg = PartitionRequest.readFrom(msg);
break;
case TaskEventRequest.ID:
decodedMsg = TaskEventRequest.readFrom(msg, getClass().getClassLoader());
break;
case ErrorResponse.ID:
decodedMsg = ErrorResponse.readFrom(msg);
break;
case CancelPartitionRequest.ID:
decodedMsg = CancelPartitionRequest.readFrom(msg);
break;
case CloseRequest.ID:
decodedMsg = CloseRequest.readFrom(msg);
break;
case AddCredit.ID:
decodedMsg = AddCredit.readFrom(msg);
break;
default:
throw new ProtocolException(
"Received unknown message from producer: " + msg);
}

out.add(decodedMsg);
return decodedMsg;
} finally {
// ByteToMessageDecoder cleanup (only the BufferResponse holds on to the decoded
// msg but already retain()s the buffer once)
msg.release();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;

import static org.apache.flink.runtime.io.network.netty.NettyMessage.NettyMessageEncoder.createFrameLengthDecoder;

/**
* Defines the server and client channel handlers, i.e. the protocol, used by netty.
*/
Expand All @@ -34,8 +32,6 @@ public class NettyProtocol {
private final NettyMessage.NettyMessageEncoder
messageEncoder = new NettyMessage.NettyMessageEncoder();

private final NettyMessage.NettyMessageDecoder messageDecoder = new NettyMessage.NettyMessageDecoder();

private final ResultPartitionProvider partitionProvider;
private final TaskEventDispatcher taskEventDispatcher;

Expand Down Expand Up @@ -64,14 +60,9 @@ public class NettyProtocol {
* | +----------+----------+ | |
* | /|\ | |
* | | | |
* | +----------+----------+ | |
* | | Message decoder | | |
* | +----------+----------+ | |
* | /|\ | |
* | | | |
* | +----------+----------+ | |
* | | Frame decoder | | |
* | +----------+----------+ | |
* | +-----------+-----------+ | |
* | | Message+Frame decoder | | |
* | +-----------+-----------+ | |
* | /|\ | |
* +---------------+-----------------------------------+---------------+
* | | (1) client request \|/
Expand All @@ -92,8 +83,7 @@ public ChannelHandler[] getServerChannelHandlers() {

return new ChannelHandler[] {
messageEncoder,
createFrameLengthDecoder(),
messageDecoder,
new NettyMessage.NettyMessageDecoder(),
serverHandler,
queueOfPartitionQueues
};
Expand All @@ -115,14 +105,9 @@ public ChannelHandler[] getServerChannelHandlers() {
* | +----------+----------+ +-----------+----------+ |
* | /|\ \|/ |
* | | | |
* | +----------+----------+ | |
* | | Message decoder | | |
* | +----------+----------+ | |
* | /|\ | |
* | | | |
* | +----------+----------+ | |
* | | Frame decoder | | |
* | +----------+----------+ | |
* | +----------+------------+ | |
* | | Message+Frame decoder | | |
* | +----------+------------+ | |
* | /|\ | |
* +---------------+-----------------------------------+---------------+
* | | (3) server response \|/ (2) client request
Expand All @@ -142,8 +127,7 @@ public ChannelHandler[] getClientChannelHandlers() {
new PartitionRequestClientHandler();
return new ChannelHandler[] {
messageEncoder,
createFrameLengthDecoder(),
messageDecoder,
new NettyMessage.NettyMessageDecoder(),
networkClientHandler};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ public class NettyMessageSerializationTest {

private final EmbeddedChannel channel = new EmbeddedChannel(
new NettyMessage.NettyMessageEncoder(), // outbound messages
NettyMessage.NettyMessageEncoder.createFrameLengthDecoder(), // inbound messages
new NettyMessage.NettyMessageDecoder()); // inbound messages

private final Random random = new Random();
Expand Down