diff --git a/extras/retrofit2/src/main/java/org/asynchttpclient/extras/retrofit/AsyncHttpClientCall.java b/extras/retrofit2/src/main/java/org/asynchttpclient/extras/retrofit/AsyncHttpClientCall.java index 6ffd3f1a88..e8980fddab 100644 --- a/extras/retrofit2/src/main/java/org/asynchttpclient/extras/retrofit/AsyncHttpClientCall.java +++ b/extras/retrofit2/src/main/java/org/asynchttpclient/extras/retrofit/AsyncHttpClientCall.java @@ -12,6 +12,7 @@ */ package org.asynchttpclient.extras.retrofit; +import io.netty.handler.codec.http.HttpHeaderNames; import lombok.*; import lombok.extern.slf4j.Slf4j; import okhttp3.*; @@ -249,7 +250,8 @@ private Response toOkhttpResponse(org.asynchttpclient.Response asyncHttpClientRe // body if (asyncHttpClientResponse.hasResponseBody()) { - val contentType = MediaType.parse(asyncHttpClientResponse.getContentType()); + val contentType = asyncHttpClientResponse.getContentType() == null + ? null : MediaType.parse(asyncHttpClientResponse.getContentType()); val okHttpBody = ResponseBody.create(contentType, asyncHttpClientResponse.getResponseBodyAsBytes()); rspBuilder.body(okHttpBody); } @@ -287,6 +289,9 @@ protected org.asynchttpclient.Request createRequest(@NonNull Request request) { // set request body val body = request.body(); if (body != null && body.contentLength() > 0) { + if (body.contentType() != null) { + requestBuilder.setHeader(HttpHeaderNames.CONTENT_TYPE, body.contentType().toString()); + } // write body to buffer val okioBuffer = new Buffer(); body.writeTo(okioBuffer); diff --git a/extras/retrofit2/src/test/java/org/asynchttpclient/extras/retrofit/AsyncHttpClientCallTest.java b/extras/retrofit2/src/test/java/org/asynchttpclient/extras/retrofit/AsyncHttpClientCallTest.java index 71f07ce2b5..68ef94624c 100644 --- a/extras/retrofit2/src/test/java/org/asynchttpclient/extras/retrofit/AsyncHttpClientCallTest.java +++ b/extras/retrofit2/src/test/java/org/asynchttpclient/extras/retrofit/AsyncHttpClientCallTest.java @@ -12,18 +12,23 @@ */ package org.asynchttpclient.extras.retrofit; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.EmptyHttpHeaders; import lombok.val; +import okhttp3.MediaType; import okhttp3.Request; +import okhttp3.RequestBody; import org.asynchttpclient.AsyncCompletionHandler; import org.asynchttpclient.AsyncHttpClient; import org.asynchttpclient.BoundRequestBuilder; import org.asynchttpclient.Response; +import org.mockito.ArgumentCaptor; import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collection; import java.util.concurrent.ExecutionException; @@ -34,8 +39,8 @@ import static org.asynchttpclient.extras.retrofit.AsyncHttpClientCall.runConsumer; import static org.asynchttpclient.extras.retrofit.AsyncHttpClientCall.runConsumers; import static org.mockito.Matchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; public class AsyncHttpClientCallTest { @@ -226,6 +231,98 @@ Object[][] dataProvider5th() { }; } + @Test + public void contentTypeHeaderIsPassedInRequest() throws Exception { + Request request = requestWithBody(); + + ArgumentCaptor capture = ArgumentCaptor.forClass(org.asynchttpclient.Request.class); + AsyncHttpClient client = mock(AsyncHttpClient.class); + + givenResponseIsProduced(client, aResponse()); + + whenRequestIsMade(client, request); + + verify(client).executeRequest(capture.capture(), any()); + + org.asynchttpclient.Request ahcRequest = capture.getValue(); + + assertTrue(ahcRequest.getHeaders().containsValue("accept", "application/vnd.hal+json", true), + "Accept header not found"); + assertEquals(ahcRequest.getHeaders().get("content-type"), "application/json", + "Content-Type header not found"); + } + + @Test + public void contenTypeIsOptionalInResponse() throws Exception { + AsyncHttpClient client = mock(AsyncHttpClient.class); + + givenResponseIsProduced(client, responseWithBody(null, "test")); + + okhttp3.Response response = whenRequestIsMade(client, REQUEST); + + assertEquals(response.code(), 200); + assertEquals(response.header("Server"), "nginx"); + assertEquals(response.body().contentType(), null); + assertEquals(response.body().string(), "test"); + } + + @Test + public void contentTypeIsProperlyParsedIfPresent() throws Exception { + AsyncHttpClient client = mock(AsyncHttpClient.class); + + givenResponseIsProduced(client, responseWithBody("text/plain", "test")); + + okhttp3.Response response = whenRequestIsMade(client, REQUEST); + + assertEquals(response.code(), 200); + assertEquals(response.header("Server"), "nginx"); + assertEquals(response.body().contentType(), MediaType.parse("text/plain")); + assertEquals(response.body().string(), "test"); + + } + + private void givenResponseIsProduced(AsyncHttpClient client, Response response) { + when(client.executeRequest(any(org.asynchttpclient.Request.class), any())).thenAnswer(invocation -> { + AsyncCompletionHandler handler = invocation.getArgumentAt(1, AsyncCompletionHandler.class); + handler.onCompleted(response); + return null; + }); + } + + private okhttp3.Response whenRequestIsMade(AsyncHttpClient client, Request request) throws IOException { + AsyncHttpClientCall call = AsyncHttpClientCall.builder().httpClient(client).request(request).build(); + + return call.execute(); + } + + private Request requestWithBody() { + return new Request.Builder() + .post(RequestBody.create(MediaType.parse("application/json"), "{\"hello\":\"world\"}".getBytes(StandardCharsets.UTF_8))) + .url("http://example.org/resource") + .addHeader("Accept", "application/vnd.hal+json") + .build(); + } + + private Response aResponse() { + Response response = mock(Response.class); + when(response.getStatusCode()).thenReturn(200); + when(response.getStatusText()).thenReturn("OK"); + when(response.hasResponseHeaders()).thenReturn(true); + when(response.getHeaders()).thenReturn(new DefaultHttpHeaders() + .add("Server", "nginx") + ); + when(response.hasResponseBody()).thenReturn(false); + return response; + } + + private Response responseWithBody(String contentType, String content) { + Response response = aResponse(); + when(response.hasResponseBody()).thenReturn(true); + when(response.getContentType()).thenReturn(contentType); + when(response.getResponseBodyAsBytes()).thenReturn(content.getBytes(StandardCharsets.UTF_8)); + return response; + } + private void doThrow(String message) { throw new RuntimeException(message); }