Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import datadog.trace.api.DDTags;
import datadog.trace.api.InstrumenterConfig;
import datadog.trace.api.ProductActivation;
import datadog.trace.api.appsec.HttpClientRequest;
import datadog.trace.api.gateway.BlockResponseFunction;
import datadog.trace.api.gateway.Flow;
import datadog.trace.api.gateway.RequestContext;
Expand Down Expand Up @@ -99,7 +100,7 @@ public AgentSpan onRequest(final AgentSpan span, final REQUEST request) {
HTTP_RESOURCE_DECORATOR.withClientPath(span, method, url.getPath());
}
// SSRF exploit prevention check
onNetworkConnection(url.toString());
onHttpClientRequest(span, url.toString());
} else if (shouldSetResourceName()) {
span.setResourceName(DEFAULT_RESOURCE_NAME);
}
Expand Down Expand Up @@ -178,24 +179,19 @@ public long getResponseContentLength(final RESPONSE response) {
return 0;
}

private void onNetworkConnection(final String networkConnection) {
protected void onHttpClientRequest(final AgentSpan span, final String url) {
if (!APPSEC_RASP_ENABLED) {
return;
}
if (networkConnection == null) {
if (url == null) {
return;
}
final BiFunction<RequestContext, String, Flow<Void>> networkConnectionCallback =
final BiFunction<RequestContext, HttpClientRequest, Flow<Void>> requestCb =
AgentTracer.get()
.getCallbackProvider(RequestContextSlot.APPSEC)
.getCallback(EVENTS.networkConnection());
.getCallback(EVENTS.httpClientRequest());

if (networkConnectionCallback == null) {
return;
}

final AgentSpan span = AgentTracer.get().activeSpan();
if (span == null) {
if (requestCb == null) {
return;
}

Expand All @@ -204,7 +200,8 @@ private void onNetworkConnection(final String networkConnection) {
return;
}

Flow<Void> flow = networkConnectionCallback.apply(ctx, networkConnection);
final long requestId = span.getSpanId();
Flow<Void> flow = requestCb.apply(ctx, new HttpClientRequest(requestId, url));
Flow.Action action = flow.getAction();
if (action instanceof Flow.Action.RequestBlockingAction) {
BlockResponseFunction brf = ctx.getBlockResponseFunction();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package datadog.trace.bootstrap.instrumentation.decorator

import datadog.trace.api.DDTags
import datadog.trace.api.appsec.HttpClientRequest
import datadog.trace.api.config.AppSecConfig
import datadog.trace.api.gateway.CallbackProvider
import static datadog.trace.api.gateway.Events.EVENTS
Expand Down Expand Up @@ -249,8 +250,8 @@ class HttpClientDecoratorTest extends ClientDecoratorTest {
decorator.onRequest(span2, req)

then:
1 * callbackProvider.getCallback(EVENTS.networkConnection()) >> listener
1 * listener.apply(reqCtx, _ as String)
1 * callbackProvider.getCallback(EVENTS.httpClientRequest()) >> listener
1 * listener.apply(reqCtx, _ as HttpClientRequest)
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.datadog.appsec.api.security;

import com.datadog.appsec.gateway.AppSecRequestContext;

public interface ApiSecurityDownstreamSampler {

boolean sampleHttpClientRequest(AppSecRequestContext ctx, long requestId);

boolean isSampled(AppSecRequestContext ctx, long requestId);

class NoOp implements ApiSecurityDownstreamSampler {

public static final NoOp INSTANCE = new NoOp();

@Override
public boolean sampleHttpClientRequest(AppSecRequestContext ctx, long requestId) {
return false;
}

@Override
public boolean isSampled(AppSecRequestContext ctx, long requestId) {
return false;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.datadog.appsec.api.security;

import com.datadog.appsec.gateway.AppSecRequestContext;
import datadog.trace.api.Config;
import java.util.concurrent.atomic.AtomicLong;

public class ApiSecurityDownstreamSamplerImpl implements ApiSecurityDownstreamSampler {

private static final long KNUTH_FACTOR = 1111111111111111111L;
private final AtomicLong globalRequestCount;
private final double threshold;

public ApiSecurityDownstreamSamplerImpl() {
this(Config.get().getApiSecurityDownstreamRequestAnalysisSampleRate());
}

public ApiSecurityDownstreamSamplerImpl(final double rate) {
threshold = samplingCutoff(rate < 0.0 ? 0 : (rate > 1.0 ? 1 : rate));
globalRequestCount = new AtomicLong(0);
}

private static double samplingCutoff(final double rate) {
final double max = Math.pow(2, 64) - 1;
if (rate < 0.5) {
return (long) (rate * max) + Long.MIN_VALUE;
}
if (rate < 1.0) {
return (long) ((rate * max) + Long.MIN_VALUE);
}
return Long.MAX_VALUE;
}

/**
* First sample the request to ensure we randomize the request and then check if the current
* server request has budget to analyze the downstream request.
*/
@Override
public boolean sampleHttpClientRequest(final AppSecRequestContext ctx, final long requestId) {
final long counter = updateRequestCount();
if (counter * KNUTH_FACTOR + Long.MIN_VALUE > threshold) {
return false;
}
return ctx.sampleHttpClientRequest(requestId);
}

@Override
public boolean isSampled(final AppSecRequestContext ctx, final long requestId) {
return ctx.isHttpClientRequestSampled(requestId);
}

private long updateRequestCount() {
return globalRequestCount.updateAndGet(cur -> (cur == Long.MAX_VALUE) ? 0L : cur + 1L);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ public interface KnownAddresses {
/** The URL of a network resource being requested (outgoing request) */
Address<String> IO_NET_URL = new Address<>("server.io.net.url");

/** The headers of a network resource being requested (outgoing request) */
Address<Map<String, List<String>>> IO_NET_REQUEST_HEADERS =
new Address<>("server.io.net.request.headers");

/** The method of a network resource being requested (outgoing request) */
Address<String> IO_NET_REQUEST_METHOD = new Address<>("server.io.net.request.method");

/** The body of a network resource being requested (outgoing request) */
Address<Object> IO_NET_REQUEST_BODY = new Address<>("server.io.net.request.body");

/** The status of a network resource being requested (outgoing request) */
Address<String> IO_NET_RESPONSE_STATUS = new Address<>("server.io.net.response.status");

/** The response headers of a network resource being requested (outgoing request) */
Address<Map<String, List<String>>> IO_NET_RESPONSE_HEADERS =
new Address<>("server.io.net.response.headers");

/** The response body of a network resource being requested (outgoing request) */
Address<Object> IO_NET_RESPONSE_BODY = new Address<>("server.io.net.response.body");

/** The representation of opened file on the filesystem */
Address<String> IO_FS_FILE = new Address<>("server.io.fs.file");

Expand Down Expand Up @@ -206,6 +226,18 @@ static Address<?> forName(String name) {
return SESSION_ID;
case "server.io.net.url":
return IO_NET_URL;
case "server.io.net.request.headers":
return IO_NET_REQUEST_HEADERS;
case "server.io.net.request.method":
return IO_NET_REQUEST_METHOD;
case "server.io.net.request.body":
return IO_NET_REQUEST_BODY;
case "server.io.net.response.status":
return IO_NET_RESPONSE_STATUS;
case "server.io.net.response.headers":
return IO_NET_RESPONSE_HEADERS;
case "server.io.net.response.body":
return IO_NET_RESPONSE_BODY;
case "server.io.fs.file":
return IO_FS_FILE;
case "server.db.system":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ public class AppSecRequestContext implements DataBundle, Closeable {
private volatile Long apiSecurityEndpointHash;
private volatile byte keepType = PrioritySampling.SAMPLER_KEEP;

private final AtomicInteger httpClientRequestCount = new AtomicInteger(0);
private final Set<Long> sampledHttpClientRequests = new HashSet<>();

private static final AtomicIntegerFieldUpdater<AppSecRequestContext> WAF_TIMEOUTS_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(AppSecRequestContext.class, "wafTimeouts");
private static final AtomicIntegerFieldUpdater<AppSecRequestContext> RASP_TIMEOUTS_UPDATER =
Expand Down Expand Up @@ -235,6 +238,26 @@ public void increaseRaspTimeouts() {
RASP_TIMEOUTS_UPDATER.incrementAndGet(this);
}

public boolean sampleHttpClientRequest(final long id) {
httpClientRequestCount.incrementAndGet();
synchronized (sampledHttpClientRequests) {
if (sampledHttpClientRequests.size()
< Config.get().getApiSecurityMaxDownstreamRequestBodyAnalysis()) {
sampledHttpClientRequests.add(id);
return true;
}
}
return false;
}

public boolean isHttpClientRequestSampled(final long id) {
return sampledHttpClientRequests.contains(id);
}

public int getHttpClientRequestCount() {
return httpClientRequestCount.get();
}

public int getWafTimeouts() {
return wafTimeouts;
}
Expand Down
Loading