Skip to content

Commit

Permalink
Clean up Netty TCP implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
kuujo committed Feb 9, 2015
1 parent 5e734e9 commit e56c7f1
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 62 deletions.
Expand Up @@ -153,7 +153,7 @@ private void recursiveSync(ReplicaInfo member, boolean requireEntries, Completab
} }
} else { } else {
// If the request failed then record the member as INACTIVE. // If the request failed then record the member as INACTIVE.
LOGGER.warn("{} - Sync to {} failed", context.getLocalMember(), member); LOGGER.warn("{} - Sync to {} failed: {}", context.getLocalMember(), member, error.getMessage());
future.completeExceptionally(error); future.completeExceptionally(error);
} }
} }
Expand Down
Expand Up @@ -16,11 +16,14 @@
package net.kuujo.copycat.netty; package net.kuujo.copycat.netty;


import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*; import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.bytes.ByteArrayDecoder;
import io.netty.handler.codec.bytes.ByteArrayEncoder;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import net.kuujo.copycat.protocol.ProtocolClient; import net.kuujo.copycat.protocol.ProtocolClient;
Expand All @@ -41,29 +44,26 @@ public class NettyTcpProtocolClient implements ProtocolClient {
private final String host; private final String host;
private final int port; private final int port;
private final NettyTcpProtocol protocol; private final NettyTcpProtocol protocol;
private EventLoopGroup group;
private Channel channel; private Channel channel;
private ChannelHandlerContext context; private ChannelHandlerContext context;
private final Map<Object, CompletableFuture<ByteBuffer>> responseFutures = new HashMap<>(1000); private final Map<Object, CompletableFuture<ByteBuffer>> responseFutures = new HashMap<>(1000);
private long requestId; private long requestId;


private final ChannelInboundHandlerAdapter channelHandler = new ChannelInboundHandlerAdapter() { private final ChannelInboundHandlerAdapter channelHandler = new SimpleChannelInboundHandler<byte[]>() {
@Override @Override
public void channelActive(ChannelHandlerContext context) { public void channelActive(ChannelHandlerContext context) {
NettyTcpProtocolClient.this.context = context; NettyTcpProtocolClient.this.context = context;
} }

@Override @Override
public void channelRead(ChannelHandlerContext context, Object message) { protected void channelRead0(ChannelHandlerContext context, byte[] message) throws Exception {
ByteBuf response = (ByteBuf) message; ByteBuffer buffer = ByteBuffer.wrap(message);
long responseId = response.readLong(); long responseId = buffer.getLong();
CompletableFuture<ByteBuffer> responseFuture = responseFutures.remove(responseId); CompletableFuture<ByteBuffer> responseFuture = responseFutures.remove(responseId);
if (responseFuture != null) { if (responseFuture != null) {
int length = response.readInt(); responseFuture.complete(buffer.slice());
ByteBuffer buffer = ByteBuffer.allocateDirect(length);
response.readBytes(buffer);
buffer.flip();
responseFuture.complete(buffer);
} }
response.release();
} }
}; };


Expand All @@ -79,11 +79,10 @@ public CompletableFuture<ByteBuffer> write(ByteBuffer request) {
final CompletableFuture<ByteBuffer> future = new CompletableFuture<>(); final CompletableFuture<ByteBuffer> future = new CompletableFuture<>();
if (channel != null) { if (channel != null) {
long requestId = ++this.requestId; long requestId = ++this.requestId;
ByteBuf buffer = context.alloc().buffer(request.remaining() + 12); // Request ID and length ByteBuffer buffer = ByteBuffer.allocate(request.limit() + 8);
buffer.writeLong(requestId); buffer.putLong(requestId);
buffer.writeInt(request.remaining()); buffer.put(request);
buffer.writeBytes(request); channel.writeAndFlush(buffer.array()).addListener((channelFuture) -> {
channel.writeAndFlush(buffer).addListener((channelFuture) -> {
if (channelFuture.isSuccess()) { if (channelFuture.isSuccess()) {
responseFutures.put(requestId, future); responseFutures.put(requestId, future);
} else { } else {
Expand Down Expand Up @@ -116,7 +115,7 @@ public CompletableFuture<Void> connect() {
sslContext = null; sslContext = null;
} }


final EventLoopGroup group = new NioEventLoopGroup(protocol.getThreads()); group = new NioEventLoopGroup(protocol.getThreads());
Bootstrap bootstrap = new Bootstrap(); Bootstrap bootstrap = new Bootstrap();
bootstrap.group(group) bootstrap.group(group)
.channel(NioSocketChannel.class) .channel(NioSocketChannel.class)
Expand All @@ -127,7 +126,11 @@ protected void initChannel(SocketChannel channel) throws Exception {
if (sslContext != null) { if (sslContext != null) {
pipeline.addLast(sslContext.newHandler(channel.alloc(), host, port)); pipeline.addLast(sslContext.newHandler(channel.alloc(), host, port));
} }
pipeline.addLast(channelHandler); pipeline.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4));
pipeline.addLast("bytesDecoder", new ByteArrayDecoder());
pipeline.addLast("frameEncoder", new LengthFieldPrepender(4));
pipeline.addLast("bytesEncoder", new ByteArrayEncoder());
pipeline.addLast("handler", channelHandler);
} }
}); });


Expand All @@ -148,15 +151,12 @@ protected void initChannel(SocketChannel channel) throws Exception {
bootstrap.option(ChannelOption.SO_KEEPALIVE, true); bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, protocol.getConnectTimeout()); bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, protocol.getConnectTimeout());


bootstrap.connect(host, port).addListener(new ChannelFutureListener() { bootstrap.connect(host, port).addListener((ChannelFutureListener) channelFuture -> {
@Override if (channelFuture.isSuccess()) {
public void operationComplete(ChannelFuture channelFuture) throws Exception { channel = channelFuture.channel();
if (channelFuture.isSuccess()) { future.complete(null);
channel = channelFuture.channel(); } else {
future.complete(null); future.completeExceptionally(channelFuture.cause());
} else {
future.completeExceptionally(channelFuture.cause());
}
} }
}); });
return future; return future;
Expand All @@ -166,15 +166,12 @@ public void operationComplete(ChannelFuture channelFuture) throws Exception {
public CompletableFuture<Void> close() { public CompletableFuture<Void> close() {
final CompletableFuture<Void> future = new CompletableFuture<>(); final CompletableFuture<Void> future = new CompletableFuture<>();
if (channel != null) { if (channel != null) {
channel.close().addListener(new ChannelFutureListener() { channel.close().addListener(channelFuture -> {
@Override group.shutdownGracefully();
public void operationComplete(ChannelFuture channelFuture) throws Exception { if (channelFuture.isSuccess()) {
channel = null; future.complete(null);
if (channelFuture.isSuccess()) { } else {
future.complete(null); future.completeExceptionally(channelFuture.cause());
} else {
future.completeExceptionally(channelFuture.cause());
}
} }
}); });
} else { } else {
Expand Down
Expand Up @@ -16,11 +16,14 @@
package net.kuujo.copycat.netty; package net.kuujo.copycat.netty;


import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*; import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.bytes.ByteArrayDecoder;
import io.netty.handler.codec.bytes.ByteArrayEncoder;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.handler.ssl.util.SelfSignedCertificate;
import net.kuujo.copycat.protocol.ProtocolHandler; import net.kuujo.copycat.protocol.ProtocolHandler;
Expand All @@ -41,6 +44,8 @@ public class NettyTcpProtocolServer implements ProtocolServer {
private final int port; private final int port;
private final NettyTcpProtocol protocol; private final NettyTcpProtocol protocol;
private ProtocolHandler handler; private ProtocolHandler handler;
private EventLoopGroup serverGroup;
private EventLoopGroup workerGroup;
private Channel channel; private Channel channel;


public NettyTcpProtocolServer(String host, int port, NettyTcpProtocol protocol) { public NettyTcpProtocolServer(String host, int port, NettyTcpProtocol protocol) {
Expand Down Expand Up @@ -71,8 +76,8 @@ public synchronized CompletableFuture<Void> listen() {
sslContext = null; sslContext = null;
} }


final EventLoopGroup serverGroup = new NioEventLoopGroup(); serverGroup = new NioEventLoopGroup();
final EventLoopGroup workerGroup = new NioEventLoopGroup(protocol.getThreads()); workerGroup = new NioEventLoopGroup(protocol.getThreads());


final ServerBootstrap bootstrap = new ServerBootstrap(); final ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(serverGroup, workerGroup) bootstrap.group(serverGroup, workerGroup)
Expand All @@ -84,7 +89,11 @@ public void initChannel(SocketChannel channel) throws Exception {
if (sslContext != null) { if (sslContext != null) {
pipeline.addLast(sslContext.newHandler(channel.alloc())); pipeline.addLast(sslContext.newHandler(channel.alloc()));
} }
pipeline.addLast(new ServerHandlerAdapter()); pipeline.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4));
pipeline.addLast("bytesDecoder", new ByteArrayDecoder());
pipeline.addLast("frameEncoder", new LengthFieldPrepender(4));
pipeline.addLast("bytesEncoder", new ByteArrayEncoder());
pipeline.addLast("handler", new ServerHandler());
} }
}) })
.option(ChannelOption.SO_BACKLOG, 128); .option(ChannelOption.SO_BACKLOG, 128);
Expand Down Expand Up @@ -129,6 +138,8 @@ public CompletableFuture<Void> close() {
final CompletableFuture<Void> future = new CompletableFuture<>(); final CompletableFuture<Void> future = new CompletableFuture<>();
if (channel != null) { if (channel != null) {
channel.close().addListener(channelFuture -> { channel.close().addListener(channelFuture -> {
workerGroup.shutdownGracefully();
serverGroup.shutdownGracefully();
if (channelFuture.isSuccess()) { if (channelFuture.isSuccess()) {
future.complete(null); future.complete(null);
} else { } else {
Expand All @@ -146,24 +157,19 @@ public String toString() {
return getClass().getSimpleName(); return getClass().getSimpleName();
} }


private class ServerHandlerAdapter extends ChannelInboundHandlerAdapter { private class ServerHandler extends SimpleChannelInboundHandler<byte[]> {
@Override @Override
public void channelRead(final ChannelHandlerContext context, Object message) { protected void channelRead0(ChannelHandlerContext context, byte[] message) throws Exception {
ByteBuf request = (ByteBuf) message;
if (handler != null) { if (handler != null) {
long requestId = request.readLong(); ByteBuffer buffer = ByteBuffer.wrap(message);
int length = request.readInt(); long requestId = buffer.getLong();
ByteBuffer buffer = request.nioBuffer(request.readerIndex(), length); handler.apply(buffer.slice()).whenComplete((result, error) -> {
handler.apply(buffer).whenComplete((result, error) -> {
if (error == null) { if (error == null) {
context.channel().eventLoop().execute(() -> { context.channel().eventLoop().execute(() -> {
ByteBuf response = context.alloc().buffer(result.remaining() + 12); // Response ID and length ByteBuffer response = ByteBuffer.allocate(result.limit() + 8);
response.writeLong(requestId); response.putLong(requestId);
response.writeInt(result.remaining()); response.put(result);
response.writeBytes(result); context.writeAndFlush(response.array());
context.writeAndFlush(response).addListener(future -> {
request.release();
});
}); });
} }
}); });
Expand Down
Expand Up @@ -192,6 +192,7 @@ public void testSendReceive() throws Throwable {
client.connect().thenRunAsync(this::resume); client.connect().thenRunAsync(this::resume);
await(5000); await(5000);


expectResume();
client.write(ByteBuffer.wrap("Hello world!".getBytes())).thenAcceptAsync(buffer -> { client.write(ByteBuffer.wrap("Hello world!".getBytes())).thenAcceptAsync(buffer -> {
byte[] bytes = new byte[buffer.remaining()]; byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes); buffer.get(bytes);
Expand All @@ -200,6 +201,7 @@ public void testSendReceive() throws Throwable {
}); });
await(5000); await(5000);


expectResume();
client.write(ByteBuffer.wrap("Hello world!".getBytes())).thenAcceptAsync(buffer -> { client.write(ByteBuffer.wrap("Hello world!".getBytes())).thenAcceptAsync(buffer -> {
byte[] bytes = new byte[buffer.remaining()]; byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes); buffer.get(bytes);
Expand All @@ -208,6 +210,7 @@ public void testSendReceive() throws Throwable {
}); });
await(5000); await(5000);


expectResume();
client.write(ByteBuffer.wrap("Hello world!".getBytes())).thenAcceptAsync(buffer -> { client.write(ByteBuffer.wrap("Hello world!".getBytes())).thenAcceptAsync(buffer -> {
byte[] bytes = new byte[buffer.remaining()]; byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes); buffer.get(bytes);
Expand All @@ -218,11 +221,11 @@ public void testSendReceive() throws Throwable {


expectResume(); expectResume();
client.close().thenRunAsync(this::resume); client.close().thenRunAsync(this::resume);
await(1000); await(2500);


expectResume(); expectResume();
server.close().thenRunAsync(this::resume); server.close().thenRunAsync(this::resume);
await(1000); await(2500);
} }


} }
12 changes: 7 additions & 5 deletions test-tools/src/main/java/net/kuujo/copycat/test/TestCluster.java
Expand Up @@ -29,6 +29,7 @@
* @author <a href="http://github.com/kuujo">Jordan Halterman</a> * @author <a href="http://github.com/kuujo">Jordan Halterman</a>
*/ */
public class TestCluster<T extends Resource<T>> { public class TestCluster<T extends Resource<T>> {
private static int id;
private final List<T> activeResources; private final List<T> activeResources;
private final List<T> passiveResources; private final List<T> passiveResources;


Expand Down Expand Up @@ -102,10 +103,10 @@ public TestCluster<T> build() {


List<T> activeResources = new ArrayList<>(activeMembers); List<T> activeResources = new ArrayList<>(activeMembers);


int i = 1;
Set<String> members = new HashSet<>(activeMembers); Set<String> members = new HashSet<>(activeMembers);
while (i <= activeMembers) { int activeCount = activeMembers + id;
String uri = uriFactory.apply(i++); while (id <= activeCount) {
String uri = uriFactory.apply(id++);
members.add(uri); members.add(uri);
} }


Expand All @@ -115,8 +116,9 @@ public TestCluster<T> build() {
} }


List<T> passiveResources = new ArrayList<>(passiveMembers); List<T> passiveResources = new ArrayList<>(passiveMembers);
while (i <= passiveMembers + activeMembers) { int passiveCount = passiveMembers + id;
String member = uriFactory.apply(i++); while (id <= passiveCount) {
String member = uriFactory.apply(id++);
ClusterConfig cluster = clusterFactory.apply(members).withLocalMember(member); ClusterConfig cluster = clusterFactory.apply(members).withLocalMember(member);
passiveResources.add(resourceFactory.apply(cluster)); passiveResources.add(resourceFactory.apply(cluster));
} }
Expand Down

0 comments on commit e56c7f1

Please sign in to comment.