Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply appsec rate limiter on event instead of when request end #7221

Merged
merged 7 commits into from
Jun 26, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class AppSecSystem {
private static AppSecConfigServiceImpl APP_SEC_CONFIG_SERVICE;
private static ReplaceableEventProducerService REPLACEABLE_EVENT_PRODUCER; // testing
private static Runnable RESET_SUBSCRIPTION_SERVICE;
private static RateLimiter RATE_LIMITER; // static for testing purpose
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any particular reason to keep RateLimiter in AppSecSystem?
I'd rather move it into PowerWAFModule 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved the RateLimiter to PowerWafModule


public static void start(SubscriptionService gw, SharedCommunicationObjects sco) {
try {
Expand Down Expand Up @@ -81,13 +82,12 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s

sco.createRemaining(config);

RateLimiter rateLimiter = getRateLimiter(config, sco.monitoring);
RATE_LIMITER = getRateLimiter(config, sco.monitoring);

GatewayBridge gatewayBridge =
new GatewayBridge(
gw,
REPLACEABLE_EVENT_PRODUCER,
rateLimiter,
requestSampler,
APP_SEC_CONFIG_SERVICE.getTraceSegmentPostProcessors());

Expand Down Expand Up @@ -148,7 +148,7 @@ private static void loadModules(EventDispatcher eventDispatcher) {
EventDispatcher.DataSubscriptionSet dataSubscriptionSet =
new EventDispatcher.DataSubscriptionSet();

final List<AppSecModule> modules = Collections.singletonList(new PowerWAFModule());
final List<AppSecModule> modules = Collections.singletonList(new PowerWAFModule(RATE_LIMITER));
for (AppSecModule module : modules) {
log.debug("Starting appsec module {}", module.getName());
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -103,6 +104,9 @@ public class AppSecRequestContext implements DataBundle, Closeable {
private boolean pathParamsPublished;
private Map<String, String> apiSchemas;

private AtomicBoolean rateLimited = new AtomicBoolean(false);
private volatile boolean throttled;

// should be guarded by this
private Additive additive;
// set after additive is set
Expand Down Expand Up @@ -486,4 +490,11 @@ boolean commitApiSchemas(TraceSegment traceSegment) {
apiSchemas.forEach(traceSegment::setTagTop);
return true;
}

public boolean isThrottled(RateLimiter rateLimiter) {
if (rateLimiter != null && rateLimited.compareAndSet(false, true)) {
throttled = rateLimiter.isThrottled();
}
return throttled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ public class GatewayBridge {

private final SubscriptionService subscriptionService;
private final EventProducerService producerService;
private final RateLimiter rateLimiter;
private final ApiSecurityRequestSampler requestSampler;
private final List<TraceSegmentPostProcessor> traceSegmentPostProcessors;

Expand All @@ -90,12 +89,10 @@ public class GatewayBridge {
public GatewayBridge(
SubscriptionService subscriptionService,
EventProducerService producerService,
RateLimiter rateLimiter,
ApiSecurityRequestSampler requestSampler,
List<TraceSegmentPostProcessor> traceSegmentPostProcessors) {
this.subscriptionService = subscriptionService;
this.producerService = producerService;
this.rateLimiter = rateLimiter;
this.requestSampler = requestSampler;
this.traceSegmentPostProcessors = traceSegmentPostProcessors;
}
Expand Down Expand Up @@ -141,47 +138,42 @@ public void init() {
pp.processTraceSegment(traceSeg, ctx, collectedEvents);
}

if (rateLimiter == null || !rateLimiter.isThrottled()) {
// If detected any events - mark span at appsec.event
if (!collectedEvents.isEmpty()) {
// Keep event related span, because it could be ignored in case of
// reduced datadog sampling rate.
traceSeg.setTagTop(Tags.ASM_KEEP, true);
traceSeg.setTagTop("appsec.event", true);
traceSeg.setTagTop("network.client.ip", ctx.getPeerAddress());

// Reflect client_ip as actor.ip for backward compatibility
Object clientIp = spanInfo.getTags().get(Tags.HTTP_CLIENT_IP);
if (clientIp != null) {
traceSeg.setTagTop("actor.ip", clientIp);
}
// If detected any events - mark span at appsec.event
if (!collectedEvents.isEmpty()) {
// Set asm keep in case that root span was not available when events are detected
traceSeg.setTagTop(Tags.ASM_KEEP, true);
traceSeg.setTagTop("appsec.event", true);
traceSeg.setTagTop("network.client.ip", ctx.getPeerAddress());

// Reflect client_ip as actor.ip for backward compatibility
Object clientIp = spanInfo.getTags().get(Tags.HTTP_CLIENT_IP);
if (clientIp != null) {
traceSeg.setTagTop("actor.ip", clientIp);
}

// Report AppSec events via "_dd.appsec.json" tag
AppSecEventWrapper wrapper = new AppSecEventWrapper(collectedEvents);
traceSeg.setDataTop("appsec", wrapper);

// Report collected request and response headers based on allow list
writeRequestHeaders(traceSeg, REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
writeResponseHeaders(
traceSeg, RESPONSE_HEADERS_ALLOW_LIST, ctx.getResponseHeaders());

// Report collected stack traces
StackTraceCollection stackTraceCollection = ctx.transferStackTracesCollection();
if (stackTraceCollection != null) {
Object flatStruct = ObjectFlattener.flatten(stackTraceCollection);
if (flatStruct != null) {
traceSeg.setMetaStructTop("_dd.stack", flatStruct);
}
}
// Report AppSec events via "_dd.appsec.json" tag
AppSecEventWrapper wrapper = new AppSecEventWrapper(collectedEvents);
traceSeg.setDataTop("appsec", wrapper);

} else if (hasUserTrackingEvent(traceSeg)) {
// Report all collected request headers on user tracking event
writeRequestHeaders(traceSeg, REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
} else {
// Report minimum set of collected request headers
writeRequestHeaders(
traceSeg, DEFAULT_REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
// Report collected request and response headers based on allow list
writeRequestHeaders(traceSeg, REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
writeResponseHeaders(traceSeg, RESPONSE_HEADERS_ALLOW_LIST, ctx.getResponseHeaders());

// Report collected stack traces
StackTraceCollection stackTraceCollection = ctx.transferStackTracesCollection();
if (stackTraceCollection != null) {
Object flatStruct = ObjectFlattener.flatten(stackTraceCollection);
if (flatStruct != null) {
traceSeg.setMetaStructTop("_dd.stack", flatStruct);
}
}
} else if (hasUserTrackingEvent(traceSeg)) {
// Report all collected request headers on user tracking event
jandro996 marked this conversation as resolved.
Show resolved Hide resolved
writeRequestHeaders(traceSeg, REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
} else {
// Report minimum set of collected request headers
writeRequestHeaders(
traceSeg, DEFAULT_REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders());
}
// If extracted any Api Schemas - commit them
if (!ctx.commitApiSchemas(traceSeg)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public RateLimiter(int limitPerSec, TimeSource timeSource, ThrottledCallback cb)
this.throttledCb = cb;
}

public final boolean isThrottled() {
public boolean isThrottled() {
long curSec = this.timeSource.getNanoTicks();
long storedState;
long newState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.datadog.appsec.event.data.DataBundle;
import com.datadog.appsec.event.data.KnownAddresses;
import com.datadog.appsec.gateway.AppSecRequestContext;
import com.datadog.appsec.gateway.RateLimiter;
import com.datadog.appsec.report.AppSecEvent;
import com.datadog.appsec.stack_trace.StackTraceEvent;
import com.datadog.appsec.stack_trace.StackTraceEvent.Frame;
Expand All @@ -28,6 +29,7 @@
import datadog.trace.api.telemetry.WafMetricCollector;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import datadog.trace.bootstrap.instrumentation.api.Tags;
import datadog.trace.util.stacktrace.StackWalkerFactory;
import io.sqreen.powerwaf.Additive;
import io.sqreen.powerwaf.Powerwaf;
Expand Down Expand Up @@ -146,9 +148,18 @@ static void createLimitsObject() {
private final PowerWAFInitializationResultReporter initReporter =
new PowerWAFInitializationResultReporter();
private final PowerWAFStatsReporter statsReporter = new PowerWAFStatsReporter();
private final RateLimiter rateLimiter;

private String currentRulesVersion;

public PowerWAFModule() {
this(null);
}

public PowerWAFModule(RateLimiter rateLimiter) {
this.rateLimiter = rateLimiter;
}

@Override
public void config(AppSecModuleConfigurer appSecConfigService)
throws AppSecModuleActivationException {
Expand Down Expand Up @@ -444,7 +455,23 @@ public void onDataAvailable(
}
}
Collection<AppSecEvent> events = buildEvents(resultWithData);
reqCtx.reportEvents(events);

if (!events.isEmpty() && !reqCtx.isThrottled(rateLimiter)) {
AgentSpan activeSpan = AgentTracer.get().activeSpan();
if (activeSpan != null) {
if (log.isDebugEnabled()) {
log.debug("Setting force-keep tag on the current span");
}
// Keep event related span, because it could be ignored in case of
// reduced datadog sampling rate.
activeSpan.getLocalRootSpan().setTag(Tags.ASM_KEEP, true);
ValentinZakharov marked this conversation as resolved.
Show resolved Hide resolved
} else {
if (log.isDebugEnabled()) {
jandro996 marked this conversation as resolved.
Show resolved Hide resolved
log.debug("There is no active span available");
}
}
jandro996 marked this conversation as resolved.
Show resolved Hide resolved
reqCtx.reportEvents(events);
}

if (flow.isBlocking()) {
reqCtx.setBlocked();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.datadog.appsec.report.AppSecEvent
import com.datadog.appsec.util.AbortStartupException
import datadog.communication.ddagent.DDAgentFeaturesDiscovery
import datadog.communication.ddagent.SharedCommunicationObjects
import datadog.communication.monitor.Counter
import datadog.communication.monitor.Monitoring
import datadog.remoteconfig.ConfigurationChangesTypedListener
import datadog.remoteconfig.ConfigurationEndListener
Expand Down Expand Up @@ -84,33 +83,17 @@ class AppSecSystemSpecification extends DDSpecification {
}

void 'honors appsec.trace.rate.limit'() {
BiFunction<RequestContext, AgentSpan, Flow<Void>> requestEndedCB
RequestContext requestContext = Mock()
TraceSegment traceSegment = Mock()
AppSecRequestContext appSecReqCtx = Mock()
def sco = sharedCommunicationObjects()
Counter throttledCounter = Mock()
IGSpanInfo span = Mock(AgentSpan)

setup:
injectSysConfig('dd.appsec.trace.rate.limit', '5')
def sco = sharedCommunicationObjects()

when:
AppSecSystem.start(subService, sco)
7.times { requestEndedCB.apply(requestContext, span) }

then:
span.getTags() >> ['http.client_ip':'1.1.1.1']
1 * sco.monitoring.newCounter('_dd.java.appsec.rate_limit.dropped_traces') >> throttledCounter
1 * subService.registerCallback(EVENTS.requestEnded(), _) >> { requestEndedCB = it[1]; null }
7 * requestContext.getData(RequestContextSlot.APPSEC) >> appSecReqCtx
7 * requestContext.traceSegment >> traceSegment
7 * appSecReqCtx.transferCollectedEvents() >> [Stub(AppSecEvent)]
// allow for one extra in case we move to another second and round down the prev count
(5..6) * appSecReqCtx.getRequestHeaders() >> [:]
(5..6) * appSecReqCtx.getResponseHeaders() >> [:]
(5..6) * traceSegment.setDataTop("appsec", _)
(1..2) * throttledCounter.increment(1)
AppSecSystem.RATE_LIMITER.limitPerSec == 5

}

void 'throws if the config file is not parseable'() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,24 @@ class AppSecRequestContextSpecification extends DDSpecification {
ctx.additive == null
!additive.online
}

void 'test isThrottled'(){
setup:
def rateLimiter = Mock(RateLimiter)
def appSecRequestContext = new AppSecRequestContext()

when: 'rate limiter is called and throttled is set'
def result = appSecRequestContext.isThrottled(rateLimiter)

then:
1 * rateLimiter.isThrottled() >> true
assert result

when: 'rate limiter is not called more than once per appsec context returns first result'
def result2 = appSecRequestContext.isThrottled(rateLimiter)

then:
0 * rateLimiter.isThrottled()
result == result2
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class GatewayBridgeIGRegistrationSpecification extends DDSpecification {
SubscriptionService ig = Mock()
EventDispatcher eventDispatcher = Mock()

GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, null, null, [])
GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, null, [])

void 'request_body_start and request_body_done are registered'() {
given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import datadog.trace.api.gateway.RequestContextSlot
import datadog.trace.api.gateway.SubscriptionService
import datadog.trace.api.http.StoredBodySupplier
import datadog.trace.api.internal.TraceSegment
import datadog.trace.api.time.TimeSource
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter
import datadog.trace.bootstrap.instrumentation.api.URIDataAdapterBase
Expand Down Expand Up @@ -59,9 +58,8 @@ class GatewayBridgeSpecification extends DDSpecification {
i
}()

RateLimiter rateLimiter = new RateLimiter(10, { -> 0L } as TimeSource, RateLimiter.ThrottledCallback.NOOP)
TraceSegmentPostProcessor pp = Mock()
GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, rateLimiter, null, [pp])
GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, null, [pp])

Supplier<Flow<AppSecRequestContext>> requestStartedCB
BiFunction<RequestContext, AgentSpan, Flow<Void>> requestEndedCB
Expand Down Expand Up @@ -139,7 +137,6 @@ class GatewayBridgeSpecification extends DDSpecification {
1 * mockAppSecCtx.transferCollectedEvents() >> [event]
1 * mockAppSecCtx.peerAddress >> '2001::1'
1 * mockAppSecCtx.close()
1 * traceSegment.setTagTop('asm.keep', true)
1 * traceSegment.setTagTop("_dd.appsec.enabled", 1)
1 * traceSegment.setTagTop("_dd.runtime_family", "jvm")
1 * traceSegment.setTagTop('appsec.event', true)
Expand All @@ -152,27 +149,6 @@ class GatewayBridgeSpecification extends DDSpecification {
flow.action == Flow.Action.Noop.INSTANCE
}

void 'event publishing is rate limited'() {
AppSecEvent event = Stub()
AppSecRequestContext mockAppSecCtx = Mock(AppSecRequestContext)
mockAppSecCtx.requestHeaders >> [:]
RequestContext mockCtx = Stub(RequestContext) {
getData(RequestContextSlot.APPSEC) >> mockAppSecCtx
getTraceSegment() >> traceSegment
}
IGSpanInfo spanInfo = Mock(AgentSpan)

when:
11.times {requestEndedCB.apply(mockCtx, spanInfo) }

then:
11 * mockAppSecCtx.transferCollectedEvents() >> [event]
11 * mockAppSecCtx.close()
11 * mockAppSecCtx.closeAdditive()
10 * spanInfo.getTags() >> ['http.client_ip':'1.1.1.1']
10 * traceSegment.setDataTop("appsec", _)
}

void 'actor ip calculated from headers'() {
AppSecRequestContext mockAppSecCtx = Mock(AppSecRequestContext)
mockAppSecCtx.requestHeaders >> [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,16 @@ class RateLimiterSpecification extends Specification {
2 * mock.nanoTicks >> initialTime - 1_000_000_000L
count == 11
}

void 'test NoOP'(){
setup:
RateLimiter rateLimiter = new RateLimiter(10, { -> 0L } as TimeSource, RateLimiter.ThrottledCallback.NOOP)
def count = 0

when:
15.times {rateLimiter.isThrottled() || count++ }

then:
count == 10
}
}
Loading
Loading