Skip to content

Commit

Permalink
Refactor propagation module to include two different APIs for strings…
Browse files Browse the repository at this point in the history
… and objects
  • Loading branch information
manuel-alvarez-alvarez committed Apr 8, 2024
1 parent 734e3c5 commit 565c62e
Show file tree
Hide file tree
Showing 211 changed files with 1,578 additions and 1,397 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ public static <E> NamedContext getOrCreate(
}
final PropagationModule module = InstrumentationBridge.PROPAGATION;
if (module != null) {
final Source source = module.findSource(target);
if (source != null) {
result = new NamedContextImpl(module, source);
final IastContext ctx = IastContext.Provider.get();
final Source source;
if (ctx != null && (source = module.findSource(ctx, target)) != null) {
result = new NamedContextImpl(module, ctx, source);
}
}
result = result == null ? NoOp.INSTANCE : result;
Expand All @@ -51,20 +52,22 @@ public void taintName(@Nullable final String name) {}

private static class NamedContextImpl extends NamedContext {
@Nonnull private final PropagationModule module;
@Nonnull private final IastContext ctx;
@Nonnull private final Source source;
@Nullable private String currentName;

private boolean fetched;
@Nullable private IastContext context;

public NamedContextImpl(@Nonnull final PropagationModule module, @Nonnull final Source source) {
public NamedContextImpl(
@Nonnull final PropagationModule module,
@Nonnull final IastContext ctx,
@Nonnull final Source source) {
this.module = module;
this.ctx = ctx;
this.source = source;
}

@Override
public void taintValue(@Nullable final String value) {
module.taint(iastCtx(), value, source.getOrigin(), currentName, source.getValue());
module.taintString(ctx, value, source.getOrigin(), currentName, source.getValue());
}

@Override
Expand All @@ -74,16 +77,8 @@ public void taintName(@Nullable final String name) {
// prevent tainting the same name more than once
if (currentName != name) {
currentName = name;
module.taint(iastCtx(), name, source.getOrigin(), name, source.getValue());
}
}

private IastContext iastCtx() {
if (!fetched) {
fetched = true;
context = IastContext.Provider.get();
module.taintString(ctx, name, source.getOrigin(), name, source.getValue());
}
return context;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,50 @@
package datadog.trace.bootstrap.instrumentation.iast


import datadog.trace.api.gateway.RequestContext
import datadog.trace.api.gateway.RequestContextSlot
import datadog.trace.api.iast.IastContext
import datadog.trace.api.iast.InstrumentationBridge
import datadog.trace.api.iast.SourceTypes
import datadog.trace.api.iast.Taintable.Source
import datadog.trace.api.iast.propagation.PropagationModule
import datadog.trace.bootstrap.ContextStore
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
import datadog.trace.bootstrap.instrumentation.api.AgentTracer
import datadog.trace.bootstrap.instrumentation.api.AgentTracer.TracerAPI
import datadog.trace.test.util.DDSpecification
import spock.lang.Shared

class NamedContextTest extends DDSpecification {

@Shared
protected static final TracerAPI ORIGINAL_TRACER = AgentTracer.get()

protected PropagationModule module
protected ContextStore store
protected TracerAPI tracer
protected IastContext ctx

void setup() {
ctx = Stub(IastContext)
final reqCtx = Stub(RequestContext) {
getData(RequestContextSlot.IAST) >> ctx
}
final span = Stub(AgentSpan) {
getRequestContext() >> reqCtx
}
tracer = Stub(TracerAPI) {
activeSpan() >> span
}
AgentTracer.forceRegister(tracer)
module = Mock(PropagationModule)
InstrumentationBridge.registerIastModule(module)
store = Mock(ContextStore)
}

void cleanup() {
AgentTracer.forceRegister(ORIGINAL_TRACER)
}

void 'test that the context taints names and values'() {
given:
final source = new SourceImpl(origin: SourceTypes.REQUEST_PARAMETER_NAME)
Expand All @@ -31,14 +57,14 @@ class NamedContextTest extends DDSpecification {

then:
1 * store.get(target) >> null
1 * module.findSource(target) >> source
1 * module.findSource(ctx, target) >> source
1 * store.put(target, _)

when:
context.taintName(name)

then:
1 * module.taint(_, name, source.origin, name, source.value)
1 * module.taintString(ctx, name, source.origin, name, source.value)

when:
context.taintName(name)
Expand All @@ -50,7 +76,7 @@ class NamedContextTest extends DDSpecification {
context.taintValue(value)

then:
1 * module.taint(_, value, source.origin, name, source.value)
1 * module.taintString(ctx, value, source.origin, name, source.value)
0 * _
}

Expand All @@ -62,7 +88,7 @@ class NamedContextTest extends DDSpecification {
final ctx = NamedContext.getOrCreate(store, target)

then:
1 * module.findSource(target) >> null
1 * module.findSource(ctx, target) >> null
1 * store.put(target, _)

when:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static datadog.trace.api.iast.VulnerabilityMarks.NOT_MARKED;

import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.propagation.CodecModule;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand All @@ -11,7 +12,10 @@ public class FastCodecModule extends PropagationModuleImpl implements CodecModul
@Override
public void onUrlDecode(
@Nonnull final String value, @Nullable final String encoding, @Nonnull final String result) {
taintIfTainted(result, value);
final IastContext ctx = IastContext.Provider.get();
if (ctx != null) {
taintStringIfTainted(ctx, result, value);
}
}

@Override
Expand All @@ -22,22 +26,34 @@ public void onStringFromBytes(
@Nullable final String charset,
@Nonnull final String result) {
// create a new range shifted to the result string coordinates
taintIfTainted(result, value, offset, length, false, NOT_MARKED);
final IastContext ctx = IastContext.Provider.get();
if (ctx != null) {
taintStringIfRangeTainted(ctx, result, value, offset, length, false, NOT_MARKED);
}
}

@Override
public void onStringGetBytes(
@Nonnull final String value, @Nullable final String charset, @Nonnull final byte[] result) {
taintIfTainted(result, value);
final IastContext ctx = IastContext.Provider.get();
if (ctx != null) {
taintObjectIfTainted(ctx, result, value);
}
}

@Override
public void onBase64Encode(@Nullable byte[] value, @Nullable byte[] result) {
taintIfTainted(result, value);
final IastContext ctx = IastContext.Provider.get();
if (ctx != null) {
taintObjectIfTainted(ctx, result, value);
}
}

@Override
public void onBase64Decode(@Nullable byte[] value, @Nullable byte[] result) {
taintIfTainted(result, value);
final IastContext ctx = IastContext.Provider.get();
if (ctx != null) {
taintObjectIfTainted(ctx, result, value);
}
}
}
Loading

0 comments on commit 565c62e

Please sign in to comment.