diff --git a/dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy b/dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy index 34f6a8a1d48..1e2b025e727 100644 --- a/dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy +++ b/dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy @@ -1021,12 +1021,14 @@ abstract class HttpClientTest extends VersionedNamingTestBase { { RequestContext rqCtxt, HttpClientRequest req -> if (req.headers?.containsKey('X-AppSec-Test')) { final context = rqCtxt.getData(RequestContextSlot.APPSEC) as Context - context.hasAppSecData = true - activeSpan() - .setTag('downstream.request.url', req.url) - .setTag('downstream.request.method', req.method) - .setTag('downstream.request.headers', JsonOutput.toJson(req.headers)) - .setTag('downstream.request.body', req.body?.text) + if (context != null) { + context.hasAppSecData = true + activeSpan() + .setTag('downstream.request.url', req.url) + .setTag('downstream.request.method', req.method) + .setTag('downstream.request.headers', JsonOutput.toJson(req.headers)) + .setTag('downstream.request.body', req.body?.text) + } } Flow.ResultFlow.empty() @@ -1035,7 +1037,7 @@ abstract class HttpClientTest extends VersionedNamingTestBase { final BiFunction> httpClientResponseCb = { RequestContext rqCtxt, HttpClientResponse res -> final context = rqCtxt.getData(RequestContextSlot.APPSEC) as Context - if (context.hasAppSecData) { + if (context?.hasAppSecData) { activeSpan() .setTag('downstream.response.status', res.status) .setTag('downstream.response.headers', JsonOutput.toJson(res.headers)) diff --git a/dd-java-agent/instrumentation/okhttp/okhttp-2.2/src/main/java/datadog/trace/instrumentation/okhttp2/AppSecInterceptor.java b/dd-java-agent/instrumentation/okhttp/okhttp-2.2/src/main/java/datadog/trace/instrumentation/okhttp2/AppSecInterceptor.java index 7e9f675dcfd..7b72fa4e69c 100644 --- a/dd-java-agent/instrumentation/okhttp/okhttp-2.2/src/main/java/datadog/trace/instrumentation/okhttp2/AppSecInterceptor.java +++ b/dd-java-agent/instrumentation/okhttp/okhttp-2.2/src/main/java/datadog/trace/instrumentation/okhttp2/AppSecInterceptor.java @@ -34,20 +34,32 @@ import okio.BufferedSource; import okio.Okio; import okio.Sink; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class AppSecInterceptor implements Interceptor { private static final int BODY_PARSING_SIZE_LIMIT = Config.get().getAppSecBodyParsingSizeLimit(); + private static final Logger LOGGER = LoggerFactory.getLogger(AppSecInterceptor.class); + @Override public Response intercept(final Chain chain) throws IOException { - final AgentSpan span = AgentTracer.activeSpan(); - final RequestContext ctx = span.getRequestContext(); - final long requestId = span.getSpanId(); - final boolean sampled = sampleRequest(ctx, requestId); - final Request request = onRequest(span, sampled, chain.request()); - final Response response = chain.proceed(request); - return onResponse(span, sampled, response); + try { + final AgentSpan span = AgentTracer.activeSpan(); + final RequestContext ctx = span == null ? null : span.getRequestContext(); + if (ctx == null) { + return chain.proceed(chain.request()); + } + final long requestId = span.getSpanId(); + final boolean sampled = sampleRequest(ctx, requestId); + final Request request = onRequest(span, sampled, chain.request()); + final Response response = chain.proceed(request); + return onResponse(span, sampled, response); + } catch (final Exception e) { + LOGGER.debug("Failed to intercept request", e); + return chain.proceed(chain.request()); + } } private Request onRequest(final AgentSpan span, final boolean sampled, final Request request) { diff --git a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/AppSecInterceptor.java b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/AppSecInterceptor.java new file mode 100644 index 00000000000..1a336307de4 --- /dev/null +++ b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/AppSecInterceptor.java @@ -0,0 +1,232 @@ +package datadog.trace.instrumentation.okhttp3; + +import static datadog.trace.api.gateway.Events.EVENTS; + +import datadog.appsec.api.blocking.BlockingException; +import datadog.trace.api.Config; +import datadog.trace.api.appsec.HttpClientPayload; +import datadog.trace.api.appsec.HttpClientRequest; +import datadog.trace.api.appsec.HttpClientResponse; +import datadog.trace.api.appsec.MediaType; +import datadog.trace.api.gateway.BlockResponseFunction; +import datadog.trace.api.gateway.CallbackProvider; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.Tags; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import okhttp3.Headers; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okio.BufferedSink; +import okio.BufferedSource; +import okio.Okio; +import okio.Sink; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class AppSecInterceptor implements Interceptor { + + private static final int BODY_PARSING_SIZE_LIMIT = Config.get().getAppSecBodyParsingSizeLimit(); + + private static final Logger LOGGER = LoggerFactory.getLogger(AppSecInterceptor.class); + + @Override + public Response intercept(final Chain chain) throws IOException { + try { + final AgentSpan span = AgentTracer.activeSpan(); + final RequestContext ctx = span == null ? null : span.getRequestContext(); + if (ctx == null) { + return chain.proceed(chain.request()); + } + final long requestId = span.getSpanId(); + final boolean sampled = sampleRequest(ctx, requestId); + final Request request = onRequest(span, sampled, chain.request()); + final Response response = chain.proceed(request); + return onResponse(span, sampled, response); + } catch (final Exception e) { + LOGGER.debug("Failed to intercept request", e); + return chain.proceed(chain.request()); + } + } + + private Request onRequest(final AgentSpan span, final boolean sampled, final Request request) { + Request result = request; + CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC); + BiFunction> requestCb = + cbp.getCallback(EVENTS.httpClientRequest()); + if (requestCb == null) { + return request; + } + + final RequestBody requestBody = request.body(); + final RequestContext ctx = span.getRequestContext(); + final long requestId = span.getSpanId(); + final String url = span.getTag(Tags.HTTP_URL).toString(); + final HttpClientRequest clientRequest = + new HttpClientRequest(requestId, url, request.method(), mapHeaders(request.headers())); + if (sampled && requestBody != null) { + // we are going to effectively read all the request body in memory to be analyzed by the WAF, + // we also modify the outbound request accordingly + final MediaType mediaType = contentType(requestBody); + try { + final long contentLength = requestBody.contentLength(); + if (shouldProcessBody(contentLength, mediaType)) { + final byte[] payload = readBody(requestBody, (int) contentLength); + if (payload.length <= BODY_PARSING_SIZE_LIMIT) { + clientRequest.setBody(mediaType, new ByteArrayInputStream(payload)); + } + result = + request + .newBuilder() + .method(request.method(), RequestBody.create(requestBody.contentType(), payload)) + .build(); // update request + } + } catch (IOException e) { + // ignore it and keep the original request + } + } + publish(ctx, clientRequest, requestCb); + return result; + } + + private Response onResponse( + final AgentSpan span, final boolean sampled, final Response response) { + Response result = response; + CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC); + BiFunction> responseCb = + cbp.getCallback(EVENTS.httpClientResponse()); + if (responseCb == null) { + return response; + } + final ResponseBody responseBody = response.body(); + final RequestContext ctx = span.getRequestContext(); + final long requestId = span.getSpanId(); + final HttpClientResponse clientResponse = + new HttpClientResponse(requestId, response.code(), mapHeaders(response.headers())); + if (sampled && responseBody != null) { + // we are going to effectively read all the response body in memory to be analyzed by the WAF, + // we also + // modify the inbound response accordingly + final MediaType mediaType = contentType(responseBody); + try { + final long contentLength = responseBody.contentLength(); + if (shouldProcessBody(contentLength, mediaType)) { + final byte[] payload = readBody(responseBody, (int) contentLength); + if (payload.length <= BODY_PARSING_SIZE_LIMIT) { + clientResponse.setBody(mediaType, new ByteArrayInputStream(payload)); + } + result = + response + .newBuilder() + .body(ResponseBody.create(responseBody.contentType(), payload)) + .build(); + } + } catch (IOException e) { + // ignore it and keep the original response + } + } + + publish(ctx, clientResponse, responseCb); + return result; + } + + private

void publish( + final RequestContext ctx, + final P request, + final BiFunction> callback) { + Flow flow = callback.apply(ctx, request); + Flow.Action action = flow.getAction(); + if (action instanceof Flow.Action.RequestBlockingAction) { + BlockResponseFunction brf = ctx.getBlockResponseFunction(); + if (brf != null) { + Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) action; + brf.tryCommitBlockingResponse( + ctx.getTraceSegment(), + rba.getStatusCode(), + rba.getBlockingContentType(), + rba.getExtraHeaders()); + } + throw new BlockingException("Blocked request (for http downstream request)"); + } + } + + private boolean sampleRequest(final RequestContext ctx, final long requestId) { + // Check if the current http request was sampled + CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC); + BiFunction> samplingCb = + cbp.getCallback(EVENTS.httpClientSampling()); + if (samplingCb == null) { + return false; + } + final Flow sampled = samplingCb.apply(ctx, requestId); + return sampled.getResult() != null && sampled.getResult(); + } + + /** + * Ensure we are only consuming payloads we can safely deserialize with a bounded size to prevent + * from OOM + */ + private boolean shouldProcessBody(final long contentLength, final MediaType mediaType) { + if (contentLength <= 0) { + return false; // prevent from copying from unbounded source (just to be safe) + } + if (BODY_PARSING_SIZE_LIMIT <= 0) { + return false; // effectively disabled by configuration + } + if (contentLength > BODY_PARSING_SIZE_LIMIT) { + return false; + } + return mediaType.isDeserializable(); + } + + private byte[] readBody(final RequestBody body, final int contentLength) throws IOException { + final ByteArrayOutputStream buffer = new ByteArrayOutputStream(contentLength); + try (final BufferedSink sink = Okio.buffer(Okio.sink(buffer))) { + body.writeTo(sink); + } + return buffer.toByteArray(); + } + + private byte[] readBody(final ResponseBody body, final int contentLength) throws IOException { + final ByteArrayOutputStream buffer = new ByteArrayOutputStream(contentLength); + try (final BufferedSource source = body.source(); + final Sink sink = Okio.sink(buffer)) { + source.readAll(sink); + } + return buffer.toByteArray(); + } + + private Map> mapHeaders(final Headers headers) { + if (headers == null) { + return Collections.emptyMap(); + } + final Map> result = new HashMap<>(headers.size()); + for (final String name : headers.names()) { + result.put(name, headers.values(name)); + } + return result; + } + + private MediaType contentType(final RequestBody body) { + return MediaType.parse( + body == null || body.contentType() == null ? null : body.contentType().toString()); + } + + private MediaType contentType(final ResponseBody body) { + return MediaType.parse( + body == null || body.contentType() == null ? null : body.contentType().toString()); + } +} diff --git a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttp3Instrumentation.java b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttp3Instrumentation.java index a1e25d85819..1745f02cc3b 100644 --- a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttp3Instrumentation.java +++ b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttp3Instrumentation.java @@ -7,6 +7,7 @@ import com.google.auto.service.AutoService; import datadog.trace.agent.tooling.Instrumenter; import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.bootstrap.ActiveSubsystems; import net.bytebuddy.asm.Advice; import okhttp3.Interceptor; import okhttp3.OkHttpClient; @@ -30,6 +31,7 @@ public String[] helperClassNames() { packageName + ".RequestBuilderInjectAdapter", packageName + ".OkHttpClientDecorator", packageName + ".TracingInterceptor", + packageName + ".AppSecInterceptor", }; } @@ -51,6 +53,9 @@ public static void addTracingInterceptor( } final TracingInterceptor interceptor = new TracingInterceptor(); builder.addInterceptor(interceptor); + if (ActiveSubsystems.APPSEC_ACTIVE) { + builder.addInterceptor(new AppSecInterceptor()); + } } } } diff --git a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttpClientDecorator.java b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttpClientDecorator.java index c9a1af82a26..7bbe47db99f 100644 --- a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttpClientDecorator.java +++ b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttpClientDecorator.java @@ -1,5 +1,6 @@ package datadog.trace.instrumentation.okhttp3; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; import datadog.trace.bootstrap.instrumentation.api.UTF8BytesString; import datadog.trace.bootstrap.instrumentation.decorator.HttpClientDecorator; import java.net.URI; @@ -58,4 +59,10 @@ protected String getRequestHeader(Request request, String headerName) { protected String getResponseHeader(Response response, String headerName) { return response.header(headerName); } + + /** Overridden by {@link AppSecInterceptor} */ + @Override + protected void onHttpClientRequest(AgentSpan span, String url) { + // do nothing + } } diff --git a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3AsyncTest.groovy b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3AsyncTest.groovy index fba9089cff1..a156273532e 100644 --- a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3AsyncTest.groovy +++ b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3AsyncTest.groovy @@ -19,7 +19,8 @@ import static java.util.concurrent.TimeUnit.SECONDS abstract class OkHttp3AsyncTest extends OkHttp3Test { @Override int doRequest(String method, URI uri, Map headers, String body, Closure callback) { - def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse("text/plain"), body) : null + final contentType = headers.remove("Content-Type") + def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse(contentType ?: "text/plain"), body) : null def request = new Request.Builder() .url(uri.toURL()) .method(method, reqBody) @@ -33,13 +34,13 @@ abstract class OkHttp3AsyncTest extends OkHttp3Test { client.newCall(request).enqueue(new Callback() { void onResponse(Call call, Response response) { responseRef.set(response) - callback?.call() + callback?.call(response.body().byteStream()) latch.countDown() } void onFailure(Call call, IOException e) { exRef.set(e) - callback?.call() + callback?.call(e) latch.countDown() } }) diff --git a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3Test.groovy b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3Test.groovy index 5f00769a0ac..ffaba2dd49d 100644 --- a/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3Test.groovy +++ b/dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3Test.groovy @@ -22,27 +22,35 @@ abstract class OkHttp3Test extends HttpClientTest { return false } + @Override + boolean testAppSecClientRequest() { + true + } + @Override boolean useStrictTraceWrites() { // TODO fix this by making sure that spans get closed properly return false } - def client = new OkHttpClient.Builder() - .connectTimeout(CONNECT_TIMEOUT_MS, TimeUnit.MILLISECONDS) - .readTimeout(READ_TIMEOUT_MS, TimeUnit.MILLISECONDS) - .writeTimeout(READ_TIMEOUT_MS, TimeUnit.MILLISECONDS) - .build() + OkHttpClient getClient() { + new OkHttpClient.Builder() + .connectTimeout(CONNECT_TIMEOUT_MS, TimeUnit.MILLISECONDS) + .readTimeout(READ_TIMEOUT_MS, TimeUnit.MILLISECONDS) + .writeTimeout(READ_TIMEOUT_MS, TimeUnit.MILLISECONDS) + .build() + } @Override int doRequest(String method, URI uri, Map headers, String body, Closure callback) { - def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse("text/plain"), body) : null + final contentType = headers.remove("Content-Type") + def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse(contentType ?: "text/plain"), body) : null def request = new Request.Builder() .url(uri.toURL()) .method(method, reqBody) .headers(Headers.of(headers)).build() def response = client.newCall(request).execute() - callback?.call() + callback?.call(response.body().byteStream()) return response.code() } diff --git a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java index 1da614b334e..04cfe0db664 100644 --- a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java +++ b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java @@ -278,17 +278,14 @@ public ResponseEntity> apiSecurityResponse( public ResponseEntity apiSecurityHttpClientOkHttp2(final HttpServletRequest request) throws IOException { // create an internal http request to the echo endpoint to validate the http client library - final String url = - ServletUriComponentsBuilder.fromRequestUri(request) - .replacePath("/echo") - .build() - .toUriString(); + final String url = getEchoUrl(request); Request.Builder clientRequest = new Request.Builder().url(url); - if (request.getMethod().equalsIgnoreCase("POST")) { + if (requiresBody(request.getMethod())) { final String contentType = request.getContentType(); final byte[] data = readFully(request.getInputStream()); clientRequest = - clientRequest.post( + clientRequest.method( + request.getMethod(), com.squareup.okhttp.RequestBody.create( com.squareup.okhttp.MediaType.parse(contentType), data)); } else { @@ -310,6 +307,41 @@ public ResponseEntity apiSecurityHttpClientOkHttp2(final HttpServletRequ return ResponseEntity.status(200).body(clientResponse.body().string()); } + @RequestMapping( + value = "/api_security/http_client/okHttp3", + method = {POST, GET, PUT}) + public ResponseEntity apiSecurityHttpClientOkHttp3(final HttpServletRequest request) + throws IOException { + // create an internal http request to the echo endpoint to validate the http client library + final String url = getEchoUrl(request); + okhttp3.Request.Builder clientRequest = new okhttp3.Request.Builder().url(url); + if (requiresBody(request.getMethod())) { + final String contentType = request.getContentType(); + final byte[] data = readFully(request.getInputStream()); + clientRequest = + clientRequest.method( + request.getMethod(), + okhttp3.RequestBody.create(okhttp3.MediaType.parse(contentType), data)); + } else { + clientRequest.method(request.getMethod(), null); + } + final String statusCode = request.getHeader("Status"); + if (statusCode != null) { + clientRequest = clientRequest.header("Status", statusCode); + } + final String witness = request.getHeader("Witness"); + if (witness != null) { + clientRequest = clientRequest.header("Witness", witness); + } + final String echoHeaders = request.getHeader("echo-headers"); + if (echoHeaders != null) { + clientRequest = clientRequest.header("echo-headers", echoHeaders); + } + final okhttp3.Response clientResponse = + new okhttp3.OkHttpClient().newCall(clientRequest.build()).execute(); + return ResponseEntity.status(200).body(clientResponse.body().string()); + } + @RequestMapping( value = "/echo", method = {POST, GET, PUT}) @@ -321,7 +353,7 @@ public ResponseEntity echo(final HttpServletRequest request) throws IOEx if (echoHeaders != null) { response = response.header("echo-headers", echoHeaders); } - if (request.getMethod().equalsIgnoreCase("POST")) { + if (requiresBody(request.getMethod())) { final String contentType = request.getContentType(); final byte[] data = readFully(request.getInputStream()); return response.contentType(MediaType.parseMediaType(contentType)).body(new String(data)); @@ -330,6 +362,17 @@ public ResponseEntity echo(final HttpServletRequest request) throws IOEx } } + private static boolean requiresBody(final String method) { + return method.equalsIgnoreCase("POST") || method.equalsIgnoreCase("PUT"); + } + + private static String getEchoUrl(final HttpServletRequest request) { + return ServletUriComponentsBuilder.fromRequestUri(request) + .replacePath("/echo") + .build() + .toUriString(); + } + private static byte[] readFully(final InputStream in) throws IOException { ByteArrayOutputStream buffer = new ByteArrayOutputStream(); byte[] data = new byte[4096]; // 4KB buffer diff --git a/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/SpringBootSmokeTest.groovy b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/SpringBootSmokeTest.groovy index 19836e926e8..eeab0b16519 100644 --- a/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/SpringBootSmokeTest.groovy +++ b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/SpringBootSmokeTest.groovy @@ -1028,7 +1028,7 @@ class SpringBootSmokeTest extends AbstractAppSecServerSmokeTest { } private static List httpClientDownstreamAnalysisVariants() { - return ['okHttp2'] + return ['okHttp2', 'okHttp3'] } private static byte[] unzip(final String text) {