Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply appsec rate limiter on event instead of when request end #7221

Merged
merged 7 commits into from
Jun 26, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,17 @@
import com.datadog.appsec.event.EventDispatcher;
import com.datadog.appsec.event.ReplaceableEventProducerService;
import com.datadog.appsec.gateway.GatewayBridge;
import com.datadog.appsec.gateway.RateLimiter;
import com.datadog.appsec.powerwaf.PowerWAFModule;
import com.datadog.appsec.util.AbortStartupException;
import com.datadog.appsec.util.StandardizedLogging;
import datadog.appsec.api.blocking.Blocking;
import datadog.appsec.api.blocking.BlockingService;
import datadog.communication.ddagent.SharedCommunicationObjects;
import datadog.communication.monitor.Counter;
import datadog.communication.monitor.Monitoring;
import datadog.remoteconfig.ConfigurationPoller;
import datadog.trace.api.Config;
import datadog.trace.api.ProductActivation;
import datadog.trace.api.gateway.SubscriptionService;
import datadog.trace.api.time.SystemTimeSource;
import datadog.trace.bootstrap.ActiveSubsystems;
import datadog.trace.util.Strings;
import java.util.Collections;
Expand Down Expand Up @@ -81,17 +78,14 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s

sco.createRemaining(config);

RateLimiter rateLimiter = getRateLimiter(config, sco.monitoring);

GatewayBridge gatewayBridge =
new GatewayBridge(
gw,
REPLACEABLE_EVENT_PRODUCER,
rateLimiter,
requestSampler,
APP_SEC_CONFIG_SERVICE.getTraceSegmentPostProcessors());

loadModules(eventDispatcher);
loadModules(eventDispatcher, sco.monitoring);

gatewayBridge.init();
RESET_SUBSCRIPTION_SERVICE = gatewayBridge::stop;
Expand All @@ -112,18 +106,6 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s
}
}

private static RateLimiter getRateLimiter(Config config, Monitoring monitoring) {
RateLimiter rateLimiter = null;
int appSecTraceRateLimit = config.getAppSecTraceRateLimit();
if (appSecTraceRateLimit > 0) {
Counter counter = monitoring.newCounter("_dd.java.appsec.rate_limit.dropped_traces");
rateLimiter =
new RateLimiter(
appSecTraceRateLimit, SystemTimeSource.INSTANCE, () -> counter.increment(1));
}
return rateLimiter;
}

public static boolean isActive() {
return ActiveSubsystems.APPSEC_ACTIVE;
}
Expand All @@ -144,11 +126,11 @@ public static void stop() {
APP_SEC_CONFIG_SERVICE.close();
}

private static void loadModules(EventDispatcher eventDispatcher) {
private static void loadModules(EventDispatcher eventDispatcher, Monitoring monitoring) {
EventDispatcher.DataSubscriptionSet dataSubscriptionSet =
new EventDispatcher.DataSubscriptionSet();

final List<AppSecModule> modules = Collections.singletonList(new PowerWAFModule());
final List<AppSecModule> modules = Collections.singletonList(new PowerWAFModule(monitoring));
for (AppSecModule module : modules) {
log.debug("Starting appsec module {}", module.getName());
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -103,6 +104,9 @@ public class AppSecRequestContext implements DataBundle, Closeable {
private boolean pathParamsPublished;
private Map<String, String> apiSchemas;

private AtomicBoolean rateLimited = new AtomicBoolean(false);
private volatile boolean throttled;

// should be guarded by this
private Additive additive;
// set after additive is set
Expand Down Expand Up @@ -486,4 +490,11 @@ boolean commitApiSchemas(TraceSegment traceSegment) {
apiSchemas.forEach(traceSegment::setTagTop);
return true;
}

public boolean isThrottled(RateLimiter rateLimiter) {
if (rateLimiter != null && rateLimited.compareAndSet(false, true)) {
throttled = rateLimiter.isThrottled();
}
return throttled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ public class GatewayBridge {

private final SubscriptionService subscriptionService;
private final EventProducerService producerService;
private final RateLimiter rateLimiter;
private final ApiSecurityRequestSampler requestSampler;
private final List<TraceSegmentPostProcessor> traceSegmentPostProcessors;

Expand All @@ -90,12 +89,10 @@ public class GatewayBridge {
public GatewayBridge(
SubscriptionService subscriptionService,
EventProducerService producerService,
RateLimiter rateLimiter,
ApiSecurityRequestSampler requestSampler,
List<TraceSegmentPostProcessor> traceSegmentPostProcessors) {
this.subscriptionService = subscriptionService;
this.producerService = producerService;
this.rateLimiter = rateLimiter;
this.requestSampler = requestSampler;
this.traceSegmentPostProcessors = traceSegmentPostProcessors;
}
Expand Down Expand Up @@ -141,47 +138,42 @@ public void init() {
pp.processTraceSegment(traceSeg, ctx, collectedEvents);
}

if (rateLimiter == null || !rateLimiter.isThrottled()) {
// If detected any events - mark span at appsec.event
if (!collectedEvents.isEmpty()) {
// Keep event related span, because it could be ignored in case of
// reduced datadog sampling rate.
traceSeg.setTagTop(Tags.ASM_KEEP, true);
traceSeg.setTagTop("appsec.event", true);
traceSeg.setTagTop("network.client.ip", ctx.getPeerAddress());

// Reflect client_ip as actor.ip for backward compatibility
Object clientIp = spanInfo.getTags().get(Tags.HTTP_CLIENT_IP);
if (clientIp != null) {
traceSeg.setTagTop("actor.ip", clientIp);
}
// If detected any events - mark span at appsec.event
if (!collectedEvents.isEmpty()) {
// Set asm keep in case that root span was not available when events are detected
traceSeg.setTagTop(Tags.ASM_KEEP, true);
traceSeg.setTagTop("appsec.event", true);
traceSeg.setTagTop("network.client.ip", ctx.getPeerAddress());

// Reflect client_ip as actor.ip for backward compatibility
Object clientIp = spanInfo.getTags().get(Tags.HTTP_CLIENT_IP);
if (clientIp != null) {
traceSeg.setTagTop("actor.ip", clientIp);
}

// Report AppSec events via "_dd.appsec.json" tag
AppSecEventWrapper wrapper = new AppSecEventWrapper(collectedEvents);
traceSeg.setDataTop("appsec", wrapper);

// Report collected request and response headers based on allow list
writeRequestHeaders(traceSeg, REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
writeResponseHeaders(
traceSeg, RESPONSE_HEADERS_ALLOW_LIST, ctx.getResponseHeaders());

// Report collected stack traces
StackTraceCollection stackTraceCollection = ctx.transferStackTracesCollection();
if (stackTraceCollection != null) {
Object flatStruct = ObjectFlattener.flatten(stackTraceCollection);
if (flatStruct != null) {
traceSeg.setMetaStructTop("_dd.stack", flatStruct);
}
}
// Report AppSec events via "_dd.appsec.json" tag
AppSecEventWrapper wrapper = new AppSecEventWrapper(collectedEvents);
traceSeg.setDataTop("appsec", wrapper);

} else if (hasUserTrackingEvent(traceSeg)) {
// Report all collected request headers on user tracking event
writeRequestHeaders(traceSeg, REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
} else {
// Report minimum set of collected request headers
writeRequestHeaders(
traceSeg, DEFAULT_REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
// Report collected request and response headers based on allow list
writeRequestHeaders(traceSeg, REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
writeResponseHeaders(traceSeg, RESPONSE_HEADERS_ALLOW_LIST, ctx.getResponseHeaders());

// Report collected stack traces
StackTraceCollection stackTraceCollection = ctx.transferStackTracesCollection();
if (stackTraceCollection != null) {
Object flatStruct = ObjectFlattener.flatten(stackTraceCollection);
if (flatStruct != null) {
traceSeg.setMetaStructTop("_dd.stack", flatStruct);
}
}
} else if (hasUserTrackingEvent(traceSeg)) {
// Report all collected request headers on user tracking event
jandro996 marked this conversation as resolved.
Show resolved Hide resolved
writeRequestHeaders(traceSeg, REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
} else {
// Report minimum set of collected request headers
writeRequestHeaders(
traceSeg, DEFAULT_REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
}
// If extracted any Api Schemas - commit them
if (!ctx.commitApiSchemas(traceSeg)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public RateLimiter(int limitPerSec, TimeSource timeSource, ThrottledCallback cb)
this.throttledCb = cb;
}

public final boolean isThrottled() {
public boolean isThrottled() {
long curSec = this.timeSource.getNanoTicks();
long storedState;
long newState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.datadog.appsec.event.data.DataBundle;
import com.datadog.appsec.event.data.KnownAddresses;
import com.datadog.appsec.gateway.AppSecRequestContext;
import com.datadog.appsec.gateway.RateLimiter;
import com.datadog.appsec.report.AppSecEvent;
import com.datadog.appsec.stack_trace.StackTraceEvent;
import com.datadog.appsec.stack_trace.StackTraceEvent.Frame;
Expand All @@ -21,13 +22,17 @@
import com.squareup.moshi.Moshi;
import com.squareup.moshi.Types;
import datadog.appsec.api.blocking.BlockingContentType;
import datadog.communication.monitor.Counter;
import datadog.communication.monitor.Monitoring;
import datadog.trace.api.Config;
import datadog.trace.api.ProductActivation;
import datadog.trace.api.gateway.Flow;
import datadog.trace.api.telemetry.LogCollector;
import datadog.trace.api.telemetry.WafMetricCollector;
import datadog.trace.api.time.SystemTimeSource;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import datadog.trace.bootstrap.instrumentation.api.Tags;
import datadog.trace.util.stacktrace.StackWalkerFactory;
import io.sqreen.powerwaf.Additive;
import io.sqreen.powerwaf.Powerwaf;
Expand Down Expand Up @@ -146,9 +151,18 @@ static void createLimitsObject() {
private final PowerWAFInitializationResultReporter initReporter =
new PowerWAFInitializationResultReporter();
private final PowerWAFStatsReporter statsReporter = new PowerWAFStatsReporter();
private final RateLimiter rateLimiter;

private String currentRulesVersion;

public PowerWAFModule() {
this(null);
}

public PowerWAFModule(Monitoring monitoring) {
this.rateLimiter = getRateLimiter(monitoring);
}

@Override
public void config(AppSecModuleConfigurer appSecConfigService)
throws AppSecModuleActivationException {
Expand Down Expand Up @@ -327,6 +341,21 @@ private PowerwafConfig createPowerwafConfig() {
return pwConfig;
}

private static RateLimiter getRateLimiter(Monitoring monitoring) {
if (monitoring == null) {
return null;
}
RateLimiter rateLimiter = null;
int appSecTraceRateLimit = Config.get().getAppSecTraceRateLimit();
if (appSecTraceRateLimit > 0) {
Counter counter = monitoring.newCounter("_dd.java.appsec.rate_limit.dropped_traces");
rateLimiter =
new RateLimiter(
appSecTraceRateLimit, SystemTimeSource.INSTANCE, () -> counter.increment(1));
}
return rateLimiter;
}

@Override
public String getName() {
return "powerwaf";
Expand Down Expand Up @@ -444,7 +473,21 @@ public void onDataAvailable(
}
}
Collection<AppSecEvent> events = buildEvents(resultWithData);
reqCtx.reportEvents(events);

if (!events.isEmpty() && !reqCtx.isThrottled(rateLimiter)) {
AgentSpan activeSpan = AgentTracer.get().activeSpan();
if (activeSpan != null) {
log.debug("Setting force-keep tag on the current span");
// Keep event related span, because it could be ignored in case of
// reduced datadog sampling rate.
activeSpan.getLocalRootSpan().setTag(Tags.ASM_KEEP, true);
ValentinZakharov marked this conversation as resolved.
Show resolved Hide resolved
} else {
// If active span is not available the ASK_KEEP tag will be set in the GatewayBridge
// when the request ends
log.debug("There is no active span available");
}
jandro996 marked this conversation as resolved.
Show resolved Hide resolved
reqCtx.reportEvents(events);
}

if (flow.isBlocking()) {
reqCtx.setBlocked();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.datadog.appsec.report.AppSecEvent
import com.datadog.appsec.util.AbortStartupException
import datadog.communication.ddagent.DDAgentFeaturesDiscovery
import datadog.communication.ddagent.SharedCommunicationObjects
import datadog.communication.monitor.Counter
import datadog.communication.monitor.Monitoring
import datadog.remoteconfig.ConfigurationChangesTypedListener
import datadog.remoteconfig.ConfigurationEndListener
Expand Down Expand Up @@ -83,36 +82,6 @@ class AppSecSystemSpecification extends DDSpecification {
1 * traceSegment.setTagTop('actor.ip', '1.1.1.1')
}

void 'honors appsec.trace.rate.limit'() {
BiFunction<RequestContext, AgentSpan, Flow<Void>> requestEndedCB
RequestContext requestContext = Mock()
TraceSegment traceSegment = Mock()
AppSecRequestContext appSecReqCtx = Mock()
def sco = sharedCommunicationObjects()
Counter throttledCounter = Mock()
IGSpanInfo span = Mock(AgentSpan)

setup:
injectSysConfig('dd.appsec.trace.rate.limit', '5')

when:
AppSecSystem.start(subService, sco)
7.times { requestEndedCB.apply(requestContext, span) }

then:
span.getTags() >> ['http.client_ip':'1.1.1.1']
1 * sco.monitoring.newCounter('_dd.java.appsec.rate_limit.dropped_traces') >> throttledCounter
1 * subService.registerCallback(EVENTS.requestEnded(), _) >> { requestEndedCB = it[1]; null }
7 * requestContext.getData(RequestContextSlot.APPSEC) >> appSecReqCtx
7 * requestContext.traceSegment >> traceSegment
7 * appSecReqCtx.transferCollectedEvents() >> [Stub(AppSecEvent)]
// allow for one extra in case we move to another second and round down the prev count
(5..6) * appSecReqCtx.getRequestHeaders() >> [:]
(5..6) * appSecReqCtx.getResponseHeaders() >> [:]
(5..6) * traceSegment.setDataTop("appsec", _)
(1..2) * throttledCounter.increment(1)
}

void 'throws if the config file is not parseable'() {
setup:
Path path = Files.createTempFile('dd-trace-', '.json')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,24 @@ class AppSecRequestContextSpecification extends DDSpecification {
ctx.additive == null
!additive.online
}

void 'test isThrottled'(){
setup:
def rateLimiter = Mock(RateLimiter)
def appSecRequestContext = new AppSecRequestContext()

when: 'rate limiter is called and throttled is set'
def result = appSecRequestContext.isThrottled(rateLimiter)

then:
1 * rateLimiter.isThrottled() >> true
assert result

when: 'rate limiter is not called more than once per appsec context returns first result'
def result2 = appSecRequestContext.isThrottled(rateLimiter)

then:
0 * rateLimiter.isThrottled()
result == result2
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class GatewayBridgeIGRegistrationSpecification extends DDSpecification {
SubscriptionService ig = Mock()
EventDispatcher eventDispatcher = Mock()

GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, null, null, [])
GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, null, [])

void 'request_body_start and request_body_done are registered'() {
given:
Expand Down
Loading
Loading