Skip to content

Commit

Permalink
Add request sampling back in http sources
Browse files Browse the repository at this point in the history
  • Loading branch information
manuel-alvarez-alvarez committed Mar 13, 2024
1 parent e8da14e commit e542e02
Show file tree
Hide file tree
Showing 114 changed files with 1,665 additions and 611 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ abstract class IastHttpServerTest<SERVER> extends WithHttpServer<SERVER> impleme
if (iastRequestContext) {
TaintedObjects taintedObjects = iastRequestContext.getTaintedObjects()
TAINTED_OBJECTS.offer(new TaintedObjectCollection(taintedObjects))
List<Vulnerability> vulns = iastRequestContext.getVulnerabilityBatch().getVulnerabilities()
List<Vulnerability> vulns = iastRequestContext.getVulnerabilityBatch().getVulnerabilities() ?: []
VULNERABILITIES.offer(vulns)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ public class TaintableEnumeration implements Enumeration<String> {

private static final String CLASS_NAME = TaintableEnumeration.class.getName();

private volatile IastContext context;
private volatile boolean contextFetched;
private final IastContext context;

private final PropagationModule module;

Expand All @@ -25,11 +24,13 @@ public class TaintableEnumeration implements Enumeration<String> {
private final Enumeration<String> delegate;

private TaintableEnumeration(
final IastContext ctx,
@NonNull final Enumeration<String> delegate,
@NonNull final PropagationModule module,
final byte origin,
@Nullable final CharSequence name,
final boolean useValueAsName) {
this.context = ctx;
this.delegate = delegate;
this.module = module;
this.origin = origin;
Expand Down Expand Up @@ -57,21 +58,13 @@ public String nextElement() {
throw e;
}
try {
module.taint(context(), next, origin, name(next));
module.taint(context, next, origin, name(next));
} catch (final Throwable e) {
module.onUnexpectedException("Failed to taint enumeration", e);
}
return next;
}

private IastContext context() {
if (!contextFetched) {
contextFetched = true;
context = IastContext.Provider.get();
}
return context;
}

private CharSequence name(final String value) {
if (name != null) {
return name;
Expand All @@ -84,18 +77,20 @@ private static boolean nonTaintableEnumerationStack(final StackTraceElement elem
}

public static Enumeration<String> wrap(
final IastContext ctx,
@NonNull final Enumeration<String> delegate,
@NonNull final PropagationModule module,
final byte origin,
@Nullable final CharSequence name) {
return new TaintableEnumeration(delegate, module, origin, name, false);
return new TaintableEnumeration(ctx, delegate, module, origin, name, false);
}

public static Enumeration<String> wrap(
final IastContext ctx,
@NonNull final Enumeration<String> delegate,
@NonNull final PropagationModule module,
final byte origin,
boolean useValueAsName) {
return new TaintableEnumeration(delegate, module, origin, null, useValueAsName);
return new TaintableEnumeration(ctx, delegate, module, origin, null, useValueAsName);
}
}
Original file line number Diff line number Diff line change
@@ -1,44 +1,26 @@
package datadog.trace.agent.tooling.iast

import datadog.trace.api.gateway.RequestContext
import datadog.trace.api.gateway.RequestContextSlot

import datadog.trace.api.iast.IastContext
import datadog.trace.api.iast.InstrumentationBridge
import datadog.trace.api.iast.SourceTypes
import datadog.trace.api.iast.propagation.PropagationModule
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
import datadog.trace.bootstrap.instrumentation.api.AgentTracer
import datadog.trace.test.util.DDSpecification
import spock.lang.Shared

class TaintableEnumerationTest extends DDSpecification {

@Shared
protected static final AgentTracer.TracerAPI ORIGINAL_TRACER = AgentTracer.get()

protected AgentTracer.TracerAPI tracer = Mock(AgentTracer.TracerAPI)

protected IastContext iastCtx = Mock(IastContext)

protected RequestContext reqCtx = Mock(RequestContext) {
getData(RequestContextSlot.IAST) >> iastCtx
}

protected AgentSpan span = Mock(AgentSpan) {
getRequestContext() >> reqCtx
}
protected IastContext iastCtx = Stub(IastContext)

protected PropagationModule module


void setup() {
AgentTracer.forceRegister(tracer)
module = Mock(PropagationModule)
InstrumentationBridge.registerIastModule(module)
}

void cleanup() {
AgentTracer.forceRegister(ORIGINAL_TRACER)
InstrumentationBridge.clearIastModules()
}

Expand All @@ -47,35 +29,34 @@ class TaintableEnumerationTest extends DDSpecification {
final values = (1..10).collect { "value$it".toString() }
final origin = SourceTypes.REQUEST_PARAMETER_NAME
final name = 'test'
final enumeration = TaintableEnumeration.wrap(Collections.enumeration(values), module, origin, name)
final enumeration = TaintableEnumeration.wrap(iastCtx, Collections.enumeration(values), module, origin, name)

when:
final result = enumeration.collect()

then:
result == values
values.each { 1 * module.taint(_, it, origin, name) }
1 * tracer.activeSpan() >> span // only one access to the active context
values.each { 1 * module.taint(iastCtx, it, origin, name) }
}

void 'underlying enumerated values are tainted with the value as a name'() {
given:
final values = (1..10).collect { "value$it".toString() }
final origin = SourceTypes.REQUEST_PARAMETER_NAME
final enumeration = TaintableEnumeration.wrap(Collections.enumeration(values), module, origin, true)
final enumeration = TaintableEnumeration.wrap(iastCtx, Collections.enumeration(values), module, origin, true)

when:
final result = enumeration.collect()

then:
result == values
values.each { 1 * module.taint(_, it, origin, it) }
values.each { 1 * module.taint(iastCtx, it, origin, it) }
}

void 'taintable enumeration leaves no trace in case of error'() {
given:
final origin = SourceTypes.REQUEST_PARAMETER_NAME
final enumeration = TaintableEnumeration.wrap(new BadEnumeration(), module, origin, true)
final enumeration = TaintableEnumeration.wrap(iastCtx, new BadEnumeration(), module, origin, true)

when:
enumeration.hasMoreElements()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
import akka.http.scaladsl.server.Directive;
import akka.http.scaladsl.server.util.Tupler$;
import com.google.auto.service.AutoService;
import datadog.trace.advice.ActiveRequestContext;
import datadog.trace.advice.RequiresRequestContext;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.Source;
import datadog.trace.api.iast.SourceTypes;
import datadog.trace.instrumentation.akkahttp.iast.helpers.TaintCookieFunction;
Expand Down Expand Up @@ -54,20 +59,28 @@ public void methodAdvice(MethodTransformer transformer) {
CookieDirectivesInstrumentation.class.getName() + "$TaintOptionalCookieAdvice");
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintCookieAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_COOKIE_VALUE)
static void after(@Advice.Return(readOnly = false) Directive directive) {
directive = directive.tmap(TaintCookieFunction.INSTANCE, Tupler$.MODULE$.forTuple(null));
static void after(
@Advice.Return(readOnly = false) Directive directive,
@ActiveRequestContext RequestContext reqCtx) {
IastContext ctx = IastContext.Provider.get(reqCtx);
directive = directive.tmap(new TaintCookieFunction(ctx), Tupler$.MODULE$.forTuple(null));
}
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintOptionalCookieAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_COOKIE_VALUE)
static void after(@Advice.Return(readOnly = false) Directive directive) {
static void after(
@Advice.Return(readOnly = false) Directive directive,
@ActiveRequestContext RequestContext reqCtx) {
IastContext ctx = IastContext.Provider.get(reqCtx);
directive =
directive.tmap(TaintOptionalCookieFunction.INSTANCE, Tupler$.MODULE$.forTuple(null));
directive.tmap(new TaintOptionalCookieFunction(ctx), Tupler$.MODULE$.forTuple(null));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
import akka.http.scaladsl.model.headers.Cookie;
import akka.http.scaladsl.model.headers.HttpCookiePair;
import com.google.auto.service.AutoService;
import datadog.trace.advice.ActiveRequestContext;
import datadog.trace.advice.RequiresRequestContext;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.Source;
Expand Down Expand Up @@ -50,11 +54,14 @@ public void methodAdvice(MethodTransformer transformer) {
CookieHeaderInstrumentation.class.getName() + "$TaintAllCookiesAdvice");
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintAllCookiesAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_COOKIE_VALUE)
static void after(
@Advice.This HttpHeader cookie, @Advice.Return Seq<HttpCookiePair> cookiePairs) {
@Advice.This HttpHeader cookie,
@Advice.Return Seq<HttpCookiePair> cookiePairs,
@ActiveRequestContext RequestContext reqCtx) {
PropagationModule prop = InstrumentationBridge.PROPAGATION;
if (prop == null || cookiePairs == null || cookiePairs.isEmpty()) {
return;
Expand All @@ -63,7 +70,7 @@ static void after(
return;
}

final IastContext ctx = IastContext.Provider.get();
final IastContext ctx = IastContext.Provider.get(reqCtx);
Iterator<HttpCookiePair> iterator = cookiePairs.iterator();
while (iterator.hasNext()) {
HttpCookiePair pair = iterator.next();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
import akka.http.scaladsl.server.directives.BasicDirectives$;
import akka.http.scaladsl.server.util.Tupler$;
import com.google.auto.service.AutoService;
import datadog.trace.advice.ActiveRequestContext;
import datadog.trace.advice.RequiresRequestContext;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.Source;
import datadog.trace.api.iast.SourceTypes;
import datadog.trace.instrumentation.akkahttp.iast.helpers.TaintRequestContextFunction;
Expand Down Expand Up @@ -69,28 +73,40 @@ private void instrumentDirective(MethodTransformer transformation, String method
ExtractDirectivesInstrumentation.class.getName() + '$' + advice);
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintUriDirectiveAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_QUERY)
static void after(@Advice.Return(readOnly = false) Directive directive) {
directive = directive.tmap(TaintUriFunction.INSTANCE, Tupler$.MODULE$.forTuple(null));
static void after(
@Advice.Return(readOnly = false) Directive directive,
@ActiveRequestContext datadog.trace.api.gateway.RequestContext reqCtx) {
IastContext ctx = IastContext.Provider.get(reqCtx);
directive = directive.tmap(new TaintUriFunction(ctx), Tupler$.MODULE$.forTuple(null));
}
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintRequestDirectiveAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_BODY)
static void after(@Advice.Return(readOnly = false) Directive directive) {
directive = directive.tmap(TaintRequestFunction.INSTANCE, Tupler$.MODULE$.forTuple(null));
static void after(
@Advice.Return(readOnly = false) Directive directive,
@ActiveRequestContext datadog.trace.api.gateway.RequestContext reqCtx) {
IastContext ctx = IastContext.Provider.get(reqCtx);
directive = directive.tmap(new TaintRequestFunction(ctx), Tupler$.MODULE$.forTuple(null));
}
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintRequestContextDirectiveAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_BODY)
static void after(@Advice.Return(readOnly = false) Directive directive) {
static void after(
@Advice.Return(readOnly = false) Directive directive,
@ActiveRequestContext datadog.trace.api.gateway.RequestContext reqCtx) {
IastContext ctx = IastContext.Provider.get(reqCtx);
directive =
directive.tmap(TaintRequestContextFunction.INSTANCE, Tupler$.MODULE$.forTuple(null));
directive.tmap(new TaintRequestContextFunction(ctx), Tupler$.MODULE$.forTuple(null));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
import akka.http.scaladsl.server.directives.FormFieldDirectives;
import akka.http.scaladsl.server.util.Tupler$;
import com.google.auto.service.AutoService;
import datadog.trace.advice.ActiveRequestContext;
import datadog.trace.advice.RequiresRequestContext;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.Source;
import datadog.trace.api.iast.SourceTypes;
import datadog.trace.instrumentation.akkahttp.iast.helpers.TaintSingleParameterFunction;
Expand Down Expand Up @@ -100,30 +105,38 @@ private void transformDirective(
ParameterDirectivesInstrumentation.class.getName() + "$" + adviceClass);
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintSingleFormFieldDirectiveOldScalaAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_PARAMETER_VALUE)
static void after(
@Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Directive retval,
@Advice.Argument(1) FormFieldDirectives.FieldMagnet fmag) {
@Advice.Argument(1) FormFieldDirectives.FieldMagnet fmag,
@ActiveRequestContext RequestContext reqCtx) {
try {
IastContext ctx = IastContext.Provider.get(reqCtx);
retval =
retval.tmap(new TaintSingleParameterFunction<>(fmag), Tupler$.MODULE$.forTuple(null));
retval.tmap(
new TaintSingleParameterFunction<>(ctx, fmag), Tupler$.MODULE$.forTuple(null));
} catch (Exception e) {
throw new RuntimeException(e); // propagate so it's logged
}
}
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintSingleFormFieldDirectiveNewScalaAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_PARAMETER_VALUE)
static void after(
@Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Directive retval,
@Advice.Argument(0) FormFieldDirectives.FieldMagnet fmag) {
@Advice.Argument(0) FormFieldDirectives.FieldMagnet fmag,
@ActiveRequestContext RequestContext reqCtx) {
try {
IastContext ctx = IastContext.Provider.get(reqCtx);
retval =
retval.tmap(new TaintSingleParameterFunction<>(fmag), Tupler$.MODULE$.forTuple(null));
retval.tmap(
new TaintSingleParameterFunction<>(ctx, fmag), Tupler$.MODULE$.forTuple(null));
} catch (Exception e) {
throw new RuntimeException(e); // propagate so it's logged
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import akka.http.javadsl.model.HttpHeader;
import datadog.trace.agent.tooling.csi.CallSite;
import datadog.trace.api.iast.IastCallSites;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.Source;
import datadog.trace.api.iast.SourceTypes;
Expand All @@ -26,7 +27,11 @@ public static String after(@CallSite.This HttpHeader header, @CallSite.Return St
return result;
}
try {
module.taintIfTainted(result, header, SourceTypes.REQUEST_HEADER_NAME, result);
final IastContext ctx = IastContext.Provider.get();
if (ctx == null) {
return result;
}
module.taintIfTainted(ctx, result, header, SourceTypes.REQUEST_HEADER_NAME, result);
} catch (final Throwable e) {
module.onUnexpectedException("onHeaderNames threw", e);
}
Expand Down

0 comments on commit e542e02

Please sign in to comment.