Skip to content

Commit

Permalink
Revamp WebSocket API, close #1393
Browse files Browse the repository at this point in the history
Motivation:

Current WebSocket API has many limitations:
1. It's not possible to be notified of received fragmented frames, they
are always aggregated
2. It’s not possible to specify extension bits when sending a frame
3. The API for being notified of write completion is cumbersome
4. There are tons of different interfaces to be notified of the
different types of frames
5. Method names are not aligned between `WebSocket` and
`WebSocketListener`
6. There are 2 onClose listeners, the default one doesn't expose the
code and reason

Modifications:

1. Add a new `aggregateWebSocketFrameFragments` config param,
defaulting to true. When false, fragmented frames are not aggregated.
Drop `WebSocketByteFragmentListener` and
`WebSocketTextFragmentListener` and add `finalFragment` and `rsv`
parameters to `WebSocketListener`
2. Provide a complete set of send methods on `WebSocket` supporting
`rsv` and fragmentation/continuation frame.
3. Have send methods return the Netty future so users can register
listeners
4. Drop `WebSocketTextListener`, `WebSocketByteListener`,
`WebSocketPingListener` and `WebSocketPongListener` and add default
methods on `WebSocketListener`. Drop `DefaultWebSocketListener`.
5. Rename all methods to `sendXXXFrame` and `onXXXFrame`
6. Drop `WebSocketCloseCodeReasonListener` and change
`WebSocketListener#onClose` to notify with code and reason.

Result:

More complete and consistent WebSocket support
  • Loading branch information
slandelle committed Apr 19, 2017
1 parent 43b20af commit 98bef40
Show file tree
Hide file tree
Showing 26 changed files with 601 additions and 945 deletions.
Expand Up @@ -279,6 +279,8 @@ public interface AsyncHttpClientConfig {


boolean isValidateResponseHeaders(); boolean isValidateResponseHeaders();


boolean isAggregateWebSocketFrameFragments();

boolean isTcpNoDelay(); boolean isTcpNoDelay();


boolean isSoReuseAddress(); boolean isSoReuseAddress();
Expand Down
Expand Up @@ -75,6 +75,7 @@ public class DefaultAsyncHttpClientConfig implements AsyncHttpClientConfig {
private final boolean keepEncodingHeader; private final boolean keepEncodingHeader;
private final ProxyServerSelector proxyServerSelector; private final ProxyServerSelector proxyServerSelector;
private final boolean validateResponseHeaders; private final boolean validateResponseHeaders;
private final boolean aggregateWebSocketFrameFragments;


// timeouts // timeouts
private final int connectTimeout; private final int connectTimeout;
Expand Down Expand Up @@ -148,6 +149,7 @@ private DefaultAsyncHttpClientConfig(//
boolean keepEncodingHeader,// boolean keepEncodingHeader,//
ProxyServerSelector proxyServerSelector,// ProxyServerSelector proxyServerSelector,//
boolean validateResponseHeaders,// boolean validateResponseHeaders,//
boolean aggregateWebSocketFrameFragments,


// timeouts // timeouts
int connectTimeout,// int connectTimeout,//
Expand Down Expand Up @@ -222,6 +224,7 @@ private DefaultAsyncHttpClientConfig(//
this.keepEncodingHeader = keepEncodingHeader; this.keepEncodingHeader = keepEncodingHeader;
this.proxyServerSelector = proxyServerSelector; this.proxyServerSelector = proxyServerSelector;
this.validateResponseHeaders = validateResponseHeaders; this.validateResponseHeaders = validateResponseHeaders;
this.aggregateWebSocketFrameFragments = aggregateWebSocketFrameFragments;


// timeouts // timeouts
this.connectTimeout = connectTimeout; this.connectTimeout = connectTimeout;
Expand Down Expand Up @@ -418,6 +421,11 @@ public boolean isValidateResponseHeaders() {
return validateResponseHeaders; return validateResponseHeaders;
} }


@Override
public boolean isAggregateWebSocketFrameFragments() {
return aggregateWebSocketFrameFragments;
}

// ssl // ssl
@Override @Override
public boolean isUseOpenSsl() { public boolean isUseOpenSsl() {
Expand Down Expand Up @@ -617,6 +625,7 @@ public static class Builder {
private boolean useProxySelector = defaultUseProxySelector(); private boolean useProxySelector = defaultUseProxySelector();
private boolean useProxyProperties = defaultUseProxyProperties(); private boolean useProxyProperties = defaultUseProxyProperties();
private boolean validateResponseHeaders = defaultValidateResponseHeaders(); private boolean validateResponseHeaders = defaultValidateResponseHeaders();
private boolean aggregateWebSocketFrameFragments = defaultAggregateWebSocketFrameFragments();


// timeouts // timeouts
private int connectTimeout = defaultConnectTimeout(); private int connectTimeout = defaultConnectTimeout();
Expand Down Expand Up @@ -819,6 +828,11 @@ public Builder setValidateResponseHeaders(boolean validateResponseHeaders) {
return this; return this;
} }


public Builder setAggregateWebSocketFrameFragments(boolean aggregateWebSocketFrameFragments) {
this.aggregateWebSocketFrameFragments = aggregateWebSocketFrameFragments;
return this;
}

public Builder setProxyServer(ProxyServer proxyServer) { public Builder setProxyServer(ProxyServer proxyServer) {
this.proxyServerSelector = uri -> proxyServer; this.proxyServerSelector = uri -> proxyServer;
return this; return this;
Expand Down Expand Up @@ -1123,6 +1137,7 @@ public DefaultAsyncHttpClientConfig build() {
keepEncodingHeader, // keepEncodingHeader, //
resolveProxyServerSelector(), // resolveProxyServerSelector(), //
validateResponseHeaders, // validateResponseHeaders, //
aggregateWebSocketFrameFragments, //
connectTimeout, // connectTimeout, //
requestTimeout, // requestTimeout, //
readTimeout, // readTimeout, //
Expand Down
Expand Up @@ -98,6 +98,10 @@ public static boolean defaultValidateResponseHeaders() {
return AsyncHttpClientConfigHelper.getAsyncHttpClientConfig().getBoolean(ASYNC_CLIENT_CONFIG_ROOT + "validateResponseHeaders"); return AsyncHttpClientConfigHelper.getAsyncHttpClientConfig().getBoolean(ASYNC_CLIENT_CONFIG_ROOT + "validateResponseHeaders");
} }


public static boolean defaultAggregateWebSocketFrameFragments() {
return AsyncHttpClientConfigHelper.getAsyncHttpClientConfig().getBoolean(ASYNC_CLIENT_CONFIG_ROOT + "aggregateWebSocketFrameFragments");
}

public static boolean defaultStrict302Handling() { public static boolean defaultStrict302Handling() {
return AsyncHttpClientConfigHelper.getAsyncHttpClientConfig().getBoolean(ASYNC_CLIENT_CONFIG_ROOT + "strict302Handling"); return AsyncHttpClientConfigHelper.getAsyncHttpClientConfig().getBoolean(ASYNC_CLIENT_CONFIG_ROOT + "strict302Handling");
} }
Expand Down
Expand Up @@ -396,7 +396,9 @@ public Bootstrap getBootstrap(Uri uri, ProxyServer proxy) {
public void upgradePipelineForWebSockets(ChannelPipeline pipeline) { public void upgradePipelineForWebSockets(ChannelPipeline pipeline) {
pipeline.addAfter(HTTP_CLIENT_CODEC, WS_ENCODER_HANDLER, new WebSocket08FrameEncoder(true)); pipeline.addAfter(HTTP_CLIENT_CODEC, WS_ENCODER_HANDLER, new WebSocket08FrameEncoder(true));
pipeline.addBefore(AHC_WS_HANDLER, WS_DECODER_HANDLER, new WebSocket08FrameDecoder(false, false, config.getWebSocketMaxFrameSize())); pipeline.addBefore(AHC_WS_HANDLER, WS_DECODER_HANDLER, new WebSocket08FrameDecoder(false, false, config.getWebSocketMaxFrameSize()));
pipeline.addAfter(WS_DECODER_HANDLER, WS_FRAME_AGGREGATOR, new WebSocketFrameAggregator(config.getWebSocketMaxBufferSize())); if (config.isAggregateWebSocketFrameFragments()) {
pipeline.addAfter(WS_DECODER_HANDLER, WS_FRAME_AGGREGATOR, new WebSocketFrameAggregator(config.getWebSocketMaxBufferSize()));
}
pipeline.remove(HTTP_CLIENT_CODEC); pipeline.remove(HTTP_CLIENT_CODEC);
} }


Expand Down
Expand Up @@ -91,6 +91,14 @@ private void abort(Channel channel, NettyResponseFuture<?> future, WebSocketUpgr
} }
} }


private static WebSocketUpgradeHandler getWebSocketUpgradeHandler(NettyResponseFuture<?> future) {
return (WebSocketUpgradeHandler) future.getAsyncHandler();
}

private static NettyWebSocket getNettyWebSocket(NettyResponseFuture<?> future) throws Exception {
return getWebSocketUpgradeHandler(future).onCompleted();
}

@Override @Override
public void handleRead(Channel channel, NettyResponseFuture<?> future, Object e) throws Exception { public void handleRead(Channel channel, NettyResponseFuture<?> future, Object e) throws Exception {


Expand All @@ -101,7 +109,7 @@ public void handleRead(Channel channel, NettyResponseFuture<?> future, Object e)
logger.debug("\n\nRequest {}\n\nResponse {}\n", httpRequest, response); logger.debug("\n\nRequest {}\n\nResponse {}\n", httpRequest, response);
} }


WebSocketUpgradeHandler handler = (WebSocketUpgradeHandler) future.getAsyncHandler(); WebSocketUpgradeHandler handler = getWebSocketUpgradeHandler(future);
HttpResponseStatus status = new NettyResponseStatus(future.getUri(), response, channel); HttpResponseStatus status = new NettyResponseStatus(future.getUri(), response, channel);
HttpResponseHeaders responseHeaders = new HttpResponseHeaders(response.headers()); HttpResponseHeaders responseHeaders = new HttpResponseHeaders(response.headers());


Expand All @@ -116,9 +124,8 @@ public void handleRead(Channel channel, NettyResponseFuture<?> future, Object e)
} }


} else if (e instanceof WebSocketFrame) { } else if (e instanceof WebSocketFrame) {
final WebSocketFrame frame = (WebSocketFrame) e; WebSocketFrame frame = (WebSocketFrame) e;
WebSocketUpgradeHandler handler = (WebSocketUpgradeHandler) future.getAsyncHandler(); NettyWebSocket webSocket = getNettyWebSocket(future);
NettyWebSocket webSocket = handler.onCompleted();
// retain because we might buffer the frame // retain because we might buffer the frame
if (webSocket.isReady()) { if (webSocket.isReady()) {
webSocket.handleFrame(frame); webSocket.handleFrame(frame);
Expand All @@ -139,11 +146,10 @@ public void handleException(NettyResponseFuture<?> future, Throwable e) {
logger.warn("onError", e); logger.warn("onError", e);


try { try {
WebSocketUpgradeHandler h = (WebSocketUpgradeHandler) future.getAsyncHandler(); NettyWebSocket webSocket = getNettyWebSocket(future);
NettyWebSocket webSocket = h.onCompleted();
if (webSocket != null) { if (webSocket != null) {
webSocket.onError(e.getCause()); webSocket.onError(e.getCause());
webSocket.close(); webSocket.sendCloseFrame();
} }
} catch (Throwable t) { } catch (Throwable t) {
logger.error("onError", t); logger.error("onError", t);
Expand All @@ -152,15 +158,13 @@ public void handleException(NettyResponseFuture<?> future, Throwable e) {


@Override @Override
public void handleChannelInactive(NettyResponseFuture<?> future) { public void handleChannelInactive(NettyResponseFuture<?> future) {
logger.trace("onClose"); logger.trace("Connection was closed abnormally (that is, with no close frame being received).");


try { try {
WebSocketUpgradeHandler h = (WebSocketUpgradeHandler) future.getAsyncHandler(); NettyWebSocket webSocket = getNettyWebSocket(future);
NettyWebSocket webSocket = h.onCompleted(); if (webSocket != null) {

webSocket.onClose(1006, "Connection was closed abnormally (that is, with no close frame being received).");
logger.trace("Connection was closed abnormally (that is, with no close frame being received)."); }
if (webSocket != null)
webSocket.close(1006, "Connection was closed abnormally (that is, with no close frame being received).");
} catch (Throwable t) { } catch (Throwable t) {
logger.error("onError", t); logger.error("onError", t);
} }
Expand Down
Expand Up @@ -339,8 +339,7 @@ public <T> void writeRequest(NettyResponseFuture<T> future, Channel channel) {
HttpRequest httpRequest = nettyRequest.getHttpRequest(); HttpRequest httpRequest = nettyRequest.getHttpRequest();
AsyncHandler<T> handler = future.getAsyncHandler(); AsyncHandler<T> handler = future.getAsyncHandler();


// if the channel is dead because it was pooled and the remote // if the channel is dead because it was pooled and the remote server decided to close it,
// server decided to close it,
// we just let it go and the channelInactive do its work // we just let it go and the channelInactive do its work
if (!Channels.isChannelValid(channel)) if (!Channels.isChannelValid(channel))
return; return;
Expand All @@ -366,6 +365,7 @@ public <T> void writeRequest(NettyResponseFuture<T> future, Channel channel) {


// if the request has a body, we want to track progress // if the request has a body, we want to track progress
if (writeBody) { if (writeBody) {
// FIXME does this really work??? the promise is for the request without body!!!
ChannelProgressivePromise promise = channel.newProgressivePromise(); ChannelProgressivePromise promise = channel.newProgressivePromise();
ChannelFuture f = channel.write(httpRequest, promise); ChannelFuture f = channel.write(httpRequest, promise);
f.addListener(new WriteProgressListener(future, true, 0L)); f.addListener(new WriteProgressListener(future, true, 0L));
Expand Down

0 comments on commit 98bef40

Please sign in to comment.