Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drain Reactor Netty Response Body on Close #23855

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@
import reactor.util.retry.Retry;

import javax.net.ssl.SSLException;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;

import static com.azure.core.http.netty.implementation.Utility.closeConnection;
Expand All @@ -51,6 +48,7 @@
* @see NettyAsyncHttpClientBuilder
*/
class NettyAsyncHttpClient implements HttpClient {
private static final String AZURE_EAGERLY_READ_RESPONSE = "azure-eagerly-read-response";
private static final String AZURE_RESPONSE_TIMEOUT = "azure-response-timeout";

final boolean disableBufferCopy;
Expand All @@ -67,7 +65,7 @@ class NettyAsyncHttpClient implements HttpClient {
* @param disableBufferCopy Determines whether deep cloning of response buffers should be disabled.
*/
NettyAsyncHttpClient(reactor.netty.http.client.HttpClient nettyClient, boolean disableBufferCopy,
long readTimeout, long writeTimeout, long responseTimeout) {
long readTimeout, long writeTimeout, long responseTimeout) {
this.nettyClient = nettyClient;
this.disableBufferCopy = disableBufferCopy;
this.readTimeout = readTimeout;
Expand All @@ -89,23 +87,21 @@ public Mono<HttpResponse> send(HttpRequest request, Context context) {
Objects.requireNonNull(request.getUrl(), "'request.getUrl()' cannot be null.");
Objects.requireNonNull(request.getUrl().getProtocol(), "'request.getUrl().getProtocol()' cannot be null.");

boolean eagerlyReadResponse = (boolean) context.getData("azure-eagerly-read-response").orElse(false);

Optional<Object> requestResponseTimeout = context.getData(AZURE_RESPONSE_TIMEOUT);
long effectiveResponseTimeout = requestResponseTimeout
.map(timeoutDuration -> ((Duration) timeoutDuration).toMillis())
.orElse(this.responseTimeout);
boolean effectiveEagerlyReadResponse = (boolean) context.getData(AZURE_EAGERLY_READ_RESPONSE).orElse(false);
long effectiveResponseTimeout = context.getData(AZURE_RESPONSE_TIMEOUT)
.filter(timeoutDuration -> timeoutDuration instanceof Duration)
.map(timeoutDuration -> ((Duration) timeoutDuration).toMillis())
.orElse(this.responseTimeout);

return nettyClient
.doOnRequest((r, connection) -> addWriteTimeoutHandler(connection, writeTimeout))
.doAfterRequest((r, connection) ->
addResponseTimeoutHandler(connection, effectiveResponseTimeout))
.doAfterRequest((r, connection) -> addResponseTimeoutHandler(connection, effectiveResponseTimeout))
.doOnResponse((response, connection) -> addReadTimeoutHandler(connection, readTimeout))
.doAfterResponseSuccess((response, connection) -> removeReadTimeoutHandler(connection))
.request(HttpMethod.valueOf(request.getHttpMethod().toString()))
.uri(request.getUrl().toString())
.send(bodySendDelegate(request))
.responseConnection(responseDelegate(request, disableBufferCopy, eagerlyReadResponse))
.responseConnection(responseDelegate(request, disableBufferCopy, effectiveEagerlyReadResponse))
.single()
.onErrorMap(throwable -> {
// The exception was an SSLException that was caused by a failure to connect to a proxy.
Expand Down Expand Up @@ -144,14 +140,15 @@ private static BiFunction<HttpClientRequest, NettyOutbound, Publisher<Void>> bod
// adding a header twice that isn't allowed, such as User-Agent, check against the initial request
// header names. If our request header already exists in the Netty request we overwrite it initially
// then append our additional values if it is a multi-value header.
final AtomicBoolean first = new AtomicBoolean(true);
hdr.getValuesList().forEach(value -> {
if (first.compareAndSet(true, false)) {
boolean first = true;
for (String value : hdr.getValuesList()) {
if (first) {
first = false;
reactorNettyRequest.header(hdr.getName(), value);
} else {
reactorNettyRequest.addHeader(hdr.getName(), value);
}
});
}
} else {
hdr.getValuesList().forEach(value -> reactorNettyRequest.addHeader(hdr.getName(), value));
}
Expand All @@ -177,18 +174,16 @@ private static BiFunction<HttpClientResponse, Connection, Publisher<HttpResponse
final HttpRequest restRequest, final boolean disableBufferCopy, final boolean eagerlyReadResponse) {
return (reactorNettyResponse, reactorNettyConnection) -> {
/*
* If we are eagerly reading the response into memory we can ignore the disable buffer copy flag as we
* MUST deeply copy the buffer to ensure it can safely be used downstream.
* If the response is being eagerly read into memory the flag for buffer copying can be ignored as the
* response MUST be deeply copied to ensure it can safely be used downstream.
*/
if (eagerlyReadResponse) {
// Setup the body flux and dispose the connection once it has been received.
Flux<ByteBuffer> body = reactorNettyConnection.inbound().receive().asByteBuffer()
.doFinally(ignored -> closeConnection(reactorNettyConnection));

return FluxUtil.collectBytesFromNetworkResponse(body,
// Set up the body flux and dispose the connection once it has been received.
return FluxUtil.collectBytesFromNetworkResponse(
reactorNettyConnection.inbound().receive().asByteBuffer(),
new NettyToAzureCoreHttpHeadersWrapper(reactorNettyResponse.responseHeaders()))
.doFinally(ignored -> closeConnection(reactorNettyConnection))
.map(bytes -> new NettyAsyncHttpBufferedResponse(reactorNettyResponse, restRequest, bytes));

} else {
return Mono.just(new NettyAsyncHttpResponse(reactorNettyResponse, reactorNettyConnection, restRequest,
disableBufferCopy));
Expand All @@ -197,32 +192,32 @@ private static BiFunction<HttpClientResponse, Connection, Publisher<HttpResponse
}

/*
* Adds the write timeout handler once the request is ready to begin sending.
* Adds write timeout handler once the request is ready to begin sending.
*/
private static void addWriteTimeoutHandler(Connection connection, long timeoutMillis) {
connection.addHandlerLast(WriteTimeoutHandler.HANDLER_NAME, new WriteTimeoutHandler(timeoutMillis));
}

/*
* First removes the write timeout handler from the connection as the request has finished sending, then adds the
* response timeout handler.
* Remove write timeout handler from the connection as the request has finished sending, then add response timeout
* handler.
*/
private static void addResponseTimeoutHandler(Connection connection, long timeoutMillis) {
connection.removeHandler(WriteTimeoutHandler.HANDLER_NAME)
.addHandlerLast(ResponseTimeoutHandler.HANDLER_NAME, new ResponseTimeoutHandler(timeoutMillis));
.addHandlerLast(ResponseTimeoutHandler.HANDLER_NAME, new ResponseTimeoutHandler(timeoutMillis));
}

/*
* First removes the response timeout handler from the connection as the response has been received, then adds the
* read timeout handler.
* Remove response timeout handler from the connection as the response has been received, then add read timeout
* handler.
*/
private static void addReadTimeoutHandler(Connection connection, long timeoutMillis) {
connection.removeHandler(ResponseTimeoutHandler.HANDLER_NAME)
.addHandlerLast(ReadTimeoutHandler.HANDLER_NAME, new ReadTimeoutHandler(timeoutMillis));
.addHandlerLast(ReadTimeoutHandler.HANDLER_NAME, new ReadTimeoutHandler(timeoutMillis));
}

/*
* Removes the read timeout handler as the complete response has been received.
* Remove read timeout handler as the complete response has been received.
*/
private static void removeReadTimeoutHandler(Connection connection) {
connection.removeHandler(ReadTimeoutHandler.HANDLER_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import io.netty.buffer.ByteBuf;
import reactor.netty.Connection;
import reactor.netty.channel.ChannelOperations;

import java.nio.ByteBuffer;

Expand Down Expand Up @@ -34,7 +35,20 @@ public static ByteBuffer deepCopyBuffer(ByteBuf byteBuf) {
* @param reactorNettyConnection The connection to close.
*/
public static void closeConnection(Connection reactorNettyConnection) {
if (!reactorNettyConnection.isDisposed()) {
// ChannelOperations is generally the default implementation of Connection used.
//
// Using the specific subclass allows for a finer grain handling.
if (reactorNettyConnection instanceof ChannelOperations) {
ChannelOperations<?, ?> channelOperations = (ChannelOperations<?, ?>) reactorNettyConnection;

// Given that this is an HttpResponse the only time this will be called is when the outbound has completed.
//
// From there the only thing that needs to be checked is whether the inbound has been disposed (completed),
// and if not dispose it (aka drain it).
if (!channelOperations.isInboundDisposed()) {
channelOperations.channel().eventLoop().execute(channelOperations::discard);
}
} else if (!reactorNettyConnection.isDisposed()) {
reactorNettyConnection.channel().eventLoop().execute(reactorNettyConnection::dispose);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -255,21 +253,14 @@ public void testConcurrentRequests() {
HttpClient httpClient = new NettyAsyncHttpClientProvider().createInstance();

int numberOfRequests = 100; // 100 = 100MB of data
byte[] expectedDigest = digest();

Mono<Long> numberOfBytesMono;
numberOfBytesMono = Flux.range(1, numberOfRequests)
Mono<Long> numberOfBytesMono = Flux.range(1, numberOfRequests)
.parallel(25)
.runOn(Schedulers.boundedElastic())
.flatMap(ignored -> httpClient.send(new HttpRequest(HttpMethod.GET, url(server, LONG_BODY_PATH)))
.flatMapMany(response -> {
MessageDigest md = md5Digest();
return response.getBody()
.doOnNext(buffer -> md.update(buffer.duplicate()))
.doOnComplete(() -> assertArrayEquals(expectedDigest, md.digest()));
}))
.flatMap(HttpResponse::getBodyAsByteArray)
.doOnNext(bytes -> assertArrayEquals(LONG_BODY, bytes)))
.sequential()
.map(ByteBuffer::remaining)
.map(bytes -> (long) bytes.length)
.reduce(0L, Long::sum);

StepVerifier.create(numberOfBytesMono)
Expand All @@ -278,20 +269,6 @@ public void testConcurrentRequests() {
.verify(Duration.ofSeconds(60));
}

private static MessageDigest md5Digest() {
try {
return MessageDigest.getInstance("MD5");
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}

private static byte[] digest() {
MessageDigest md = md5Digest();
md.update(NettyAsyncHttpClientTests.LONG_BODY);
return md.digest();
}

/**
* Tests that deep copying the buffers returned by Netty will make the stream returned to the customer resilient to
* Netty reclaiming them once the 'onNext' operator chain has completed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,6 @@ public void getBodyAsStringWithCharset() {
.verifyComplete();
}

@Test
public void close() {
Connection connection = mock(Connection.class);
when(connection.isDisposed()).thenReturn(true);

new NettyAsyncHttpResponse(null, connection, REQUEST, false).close();

verify(connection, times(1)).isDisposed();
}

@ParameterizedTest
@MethodSource("verifyDisposalSupplier")
public void verifyDisposal(String methodName, Class<?>[] argumentTypes, Object[] argumentValues)
Expand All @@ -193,7 +183,6 @@ public void verifyDisposal(String methodName, Class<?>[] argumentTypes, Object[]

Connection connection = mock(Connection.class);
when(connection.inbound()).thenReturn(nettyInbound);
when(connection.isDisposed()).thenReturn(false);
when(connection.channel()).thenReturn(channel);

NettyAsyncHttpResponse response = new NettyAsyncHttpResponse(reactorNettyResponse, connection, REQUEST,
Expand All @@ -206,7 +195,6 @@ public void verifyDisposal(String methodName, Class<?>[] argumentTypes, Object[]
((Flux<?>) object).blockLast();
}

verify(connection, times(1)).isDisposed();
verify(eventLoop, times(1)).execute(any());
}

Expand Down
Loading