Skip to content

Commit

Permalink
Handle WebSocket messages when they come in the same frame as the Upg…
Browse files Browse the repository at this point in the history
…rade response, close #1095
  • Loading branch information
slandelle committed Feb 16, 2016
1 parent ceba30b commit c3a8920
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 48 deletions.
Expand Up @@ -471,9 +471,9 @@ public Bootstrap getBootstrap(Uri uri, ProxyServer proxy) {


public void upgradePipelineForWebSockets(ChannelPipeline pipeline) { public void upgradePipelineForWebSockets(ChannelPipeline pipeline) {
pipeline.addAfter(HTTP_CLIENT_CODEC, WS_ENCODER_HANDLER, new WebSocket08FrameEncoder(true)); pipeline.addAfter(HTTP_CLIENT_CODEC, WS_ENCODER_HANDLER, new WebSocket08FrameEncoder(true));
pipeline.remove(HTTP_CLIENT_CODEC);
pipeline.addBefore(AHC_WS_HANDLER, WS_DECODER_HANDLER, new WebSocket08FrameDecoder(false, false, config.getWebSocketMaxFrameSize())); pipeline.addBefore(AHC_WS_HANDLER, WS_DECODER_HANDLER, new WebSocket08FrameDecoder(false, false, config.getWebSocketMaxFrameSize()));
pipeline.addAfter(WS_DECODER_HANDLER, WS_FRAME_AGGREGATOR, new WebSocketFrameAggregator(config.getWebSocketMaxBufferSize())); pipeline.addAfter(WS_DECODER_HANDLER, WS_FRAME_AGGREGATOR, new WebSocketFrameAggregator(config.getWebSocketMaxBufferSize()));
pipeline.remove(HTTP_CLIENT_CODEC);
} }


public final Callback newDrainCallback(final NettyResponseFuture<?> future, final Channel channel, final boolean keepAlive, final Object partitionKey) { public final Callback newDrainCallback(final NettyResponseFuture<?> future, final Channel channel, final boolean keepAlive, final Object partitionKey) {
Expand Down
Expand Up @@ -54,18 +54,6 @@ public WebSocketHandler(AsyncHttpClientConfig config,//
super(config, channelManager, requestSender); super(config, channelManager, requestSender);
} }


// We don't need to synchronize as replacing the "ws-decoder" will
// process using the same thread.
private void invokeOnSucces(Channel channel, WebSocketUpgradeHandler h) {
if (!h.touchSuccess()) {
try {
h.onSuccess(new NettyWebSocket(channel, config));
} catch (Exception ex) {
logger.warn("onSuccess unexpected exception", ex);
}
}
}

private class UpgradeCallback extends Callback { private class UpgradeCallback extends Callback {


private final Channel channel; private final Channel channel;
Expand All @@ -84,6 +72,18 @@ public UpgradeCallback(NettyResponseFuture<?> future, Channel channel, HttpRespo
this.responseHeaders = responseHeaders; this.responseHeaders = responseHeaders;
} }


// We don't need to synchronize as replacing the "ws-decoder" will
// process using the same thread.
private void invokeOnSucces(Channel channel, WebSocketUpgradeHandler h) {
if (!h.touchSuccess()) {
try {
h.onSuccess(new NettyWebSocket(channel, config));
} catch (Exception ex) {
logger.warn("onSuccess unexpected exception", ex);
}
}
}

@Override @Override
public void call() throws Exception { public void call() throws Exception {


Expand Down Expand Up @@ -116,14 +116,16 @@ public void call() throws Exception {
requestSender.abort(channel, future, new IOException(String.format("Invalid challenge. Actual: %s. Expected: %s", accept, key))); requestSender.abort(channel, future, new IOException(String.format("Invalid challenge. Actual: %s. Expected: %s", accept, key)));
} }


// set back the future so the protocol gets notified of frames
// removing the HttpClientCodec from the pipeline might trigger a read with a WebSocket message
// if it comes in the same frame as the HTTP Upgrade response
Channels.setAttribute(channel, future);

channelManager.upgradePipelineForWebSockets(channel.pipeline()); channelManager.upgradePipelineForWebSockets(channel.pipeline());


invokeOnSucces(channel, handler); invokeOnSucces(channel, handler);
future.done(); future.done();
// set back the future so the protocol gets notified of frames
Channels.setAttribute(channel, future);
} }

} }


@Override @Override
Expand All @@ -144,43 +146,61 @@ public void handleRead(Channel channel, NettyResponseFuture<?> future, Object e)
Channels.setAttribute(channel, new UpgradeCallback(future, channel, response, handler, status, responseHeaders)); Channels.setAttribute(channel, new UpgradeCallback(future, channel, response, handler, status, responseHeaders));
} }



} else if (e instanceof WebSocketFrame) { } else if (e instanceof WebSocketFrame) {

final WebSocketFrame frame = (WebSocketFrame) e; final WebSocketFrame frame = (WebSocketFrame) e;
WebSocketUpgradeHandler handler = (WebSocketUpgradeHandler) future.getAsyncHandler(); WebSocketUpgradeHandler handler = (WebSocketUpgradeHandler) future.getAsyncHandler();
NettyWebSocket webSocket = (NettyWebSocket) handler.onCompleted(); NettyWebSocket webSocket = (NettyWebSocket) handler.onCompleted();


if (webSocket != null) { if (webSocket != null) {
if (frame instanceof CloseWebSocketFrame) { handleFrame(channel, frame, handler, webSocket);
Channels.setDiscard(channel);
CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame;
webSocket.onClose(closeFrame.statusCode(), closeFrame.reasonText());
} else {
ByteBuf buf = frame.content();
if (buf != null && buf.readableBytes() > 0) {
HttpResponseBodyPart part = config.getResponseBodyPartFactory().newResponseBodyPart(buf, frame.isFinalFragment());
handler.onBodyPartReceived(part);

if (frame instanceof BinaryWebSocketFrame) {
webSocket.onBinaryFragment(part);
} else if (frame instanceof TextWebSocketFrame) {
webSocket.onTextFragment(part);
} else if (frame instanceof PingWebSocketFrame) {
webSocket.onPing(part);
} else if (frame instanceof PongWebSocketFrame) {
webSocket.onPong(part);
}
}
}
} else { } else {
logger.debug("UpgradeHandler returned a null NettyWebSocket"); logger.debug("Frame received but WebSocket is not available yet, buffering frame");
frame.retain();
Runnable bufferedFrame = new Runnable() {
public void run() {
try {
// WebSocket is now not null
NettyWebSocket webSocket = (NettyWebSocket) handler.onCompleted();
handleFrame(channel, frame, handler, webSocket);
} catch (Exception e) {
logger.debug("Failure while handling buffered frame", e);
handler.onFailure(e);
} finally {
frame.release();
}
};
};
handler.bufferFrame(bufferedFrame);
} }
} else { } else {
logger.error("Invalid message {}", e); logger.error("Invalid message {}", e);
} }
} }


private void handleFrame(Channel channel, WebSocketFrame frame, WebSocketUpgradeHandler handler, NettyWebSocket webSocket) throws Exception {
if (frame instanceof CloseWebSocketFrame) {
Channels.setDiscard(channel);
CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame;
webSocket.onClose(closeFrame.statusCode(), closeFrame.reasonText());
} else {
ByteBuf buf = frame.content();
if (buf != null && buf.readableBytes() > 0) {
HttpResponseBodyPart part = config.getResponseBodyPartFactory().newResponseBodyPart(buf, frame.isFinalFragment());
handler.onBodyPartReceived(part);

if (frame instanceof BinaryWebSocketFrame) {
webSocket.onBinaryFragment(part);
} else if (frame instanceof TextWebSocketFrame) {
webSocket.onTextFragment(part);
} else if (frame instanceof PingWebSocketFrame) {
webSocket.onPing(part);
} else if (frame instanceof PongWebSocketFrame) {
webSocket.onPong(part);
}
}
}
}

@Override @Override
public void handleException(NettyResponseFuture<?> future, Throwable e) { public void handleException(NettyResponseFuture<?> future, Throwable e) {
logger.warn("onError {}", e); logger.warn("onError {}", e);
Expand Down
Expand Up @@ -12,7 +12,7 @@
*/ */
package org.asynchttpclient.ws; package org.asynchttpclient.ws;


import static org.asynchttpclient.util.Assertions.*; import static org.asynchttpclient.util.MiscUtils.isNonEmpty;


import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
Expand All @@ -28,16 +28,26 @@
*/ */
public class WebSocketUpgradeHandler implements UpgradeHandler<WebSocket>, AsyncHandler<WebSocket> { public class WebSocketUpgradeHandler implements UpgradeHandler<WebSocket>, AsyncHandler<WebSocket> {


private static final int SWITCHING_PROTOCOLS = io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS.code();

private WebSocket webSocket; private WebSocket webSocket;
private final List<WebSocketListener> listeners; private final List<WebSocketListener> listeners;
private final AtomicBoolean ok = new AtomicBoolean(false); private final AtomicBoolean ok = new AtomicBoolean(false);
private boolean onSuccessCalled; private boolean onSuccessCalled;
private int status; private int status;
private List<Runnable> bufferedFrames;


public WebSocketUpgradeHandler(List<WebSocketListener> listeners) { public WebSocketUpgradeHandler(List<WebSocketListener> listeners) {
this.listeners = listeners; this.listeners = listeners;
} }


public void bufferFrame(Runnable bufferedFrame) {
if (bufferedFrames == null) {
bufferedFrames = new ArrayList<>();
}
bufferedFrames.add(bufferedFrame);
}

/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
Expand Down Expand Up @@ -66,11 +76,7 @@ public final State onBodyPartReceived(HttpResponseBodyPart bodyPart) throws Exce
@Override @Override
public final State onStatusReceived(HttpResponseStatus responseStatus) throws Exception { public final State onStatusReceived(HttpResponseStatus responseStatus) throws Exception {
status = responseStatus.getStatusCode(); status = responseStatus.getStatusCode();
if (responseStatus.getStatusCode() == 101) { return status == SWITCHING_PROTOCOLS ? State.UPGRADE : State.ABORT;
return State.UPGRADE;
} else {
return State.ABORT;
}
} }


/** /**
Expand All @@ -87,15 +93,15 @@ public final State onHeadersReceived(HttpResponseHeaders headers) throws Excepti
@Override @Override
public final WebSocket onCompleted() throws Exception { public final WebSocket onCompleted() throws Exception {


if (status != 101) { if (status != SWITCHING_PROTOCOLS) {
IllegalStateException e = new IllegalStateException("Invalid Status Code " + status); IllegalStateException e = new IllegalStateException("Invalid Status Code " + status);
for (WebSocketListener listener : listeners) { for (WebSocketListener listener : listeners) {
listener.onError(e); listener.onError(e);
} }
throw e; throw e;
} }


return assertNotNull(webSocket, "webSocket"); return webSocket;
} }


/** /**
Expand All @@ -108,6 +114,12 @@ public final void onSuccess(WebSocket webSocket) {
webSocket.addWebSocketListener(listener); webSocket.addWebSocketListener(listener);
listener.onOpen(webSocket); listener.onOpen(webSocket);
} }
if (isNonEmpty(bufferedFrames)) {
for (Runnable bufferedFrame : bufferedFrames) {
bufferedFrame.run();
}
bufferedFrames = null;
}
ok.set(true); ok.set(true);
} }


Expand Down

0 comments on commit c3a8920

Please sign in to comment.