Skip to content

Commit

Permalink
Remove fast codec module as only one will ever exist (#7302)
Browse files Browse the repository at this point in the history
Remove fast codec module as only one will ever exist
  • Loading branch information
manuel-alvarez-alvarez committed Jul 11, 2024
1 parent 1294ac2 commit e25c56d
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import static datadog.trace.api.iast.IastDetectionMode.UNLIMITED;

import com.datadog.iast.overhead.OverheadController;
import com.datadog.iast.propagation.FastCodecModule;
import com.datadog.iast.propagation.CodecModuleImpl;
import com.datadog.iast.propagation.PropagationModuleImpl;
import com.datadog.iast.propagation.StringModuleImpl;
import com.datadog.iast.sink.ApplicationModuleImpl;
Expand Down Expand Up @@ -122,7 +122,7 @@ private static Stream<IastModule> iastModules(
Stream<Class<? extends IastModule>> modules =
Stream.of(
StringModuleImpl.class,
FastCodecModule.class,
CodecModuleImpl.class,
SqlInjectionModuleImpl.class,
PathTraversalModuleImpl.class,
CommandInjectionModuleImpl.class,
Expand Down Expand Up @@ -167,12 +167,18 @@ private static boolean isOptOut(final Class<? extends IastModule> module) {
private static <M extends IastModule> M newIastModule(
final Dependencies dependencies, final Class<M> type) {
try {
final Constructor<M> ctor = (Constructor<M>) type.getDeclaredConstructors()[0];
if (ctor.getParameterCount() == 0) {
return ctor.newInstance();
} else {
return ctor.newInstance(dependencies);
for (final Constructor<?> ctor : type.getDeclaredConstructors()) {
switch (ctor.getParameterCount()) {
case 0:
return (M) ctor.newInstance();
case 1:
if (ctor.getParameterTypes()[0] == Dependencies.class) {
return (M) ctor.newInstance(dependencies);
}
break;
}
}
throw new RuntimeException("Cannot find constructor for the module " + type);
} catch (final Throwable e) {
// should never happen and be caught on IAST tests
throw new UndeclaredThrowableException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,26 @@
import static datadog.trace.api.iast.VulnerabilityMarks.NOT_MARKED;

import datadog.trace.api.iast.propagation.CodecModule;
import datadog.trace.api.iast.propagation.PropagationModule;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

public class FastCodecModule extends PropagationModuleImpl implements CodecModule {
public class CodecModuleImpl implements CodecModule {

private final PropagationModule propagationModule;

public CodecModuleImpl() {
this(new PropagationModuleImpl());
}

CodecModuleImpl(final PropagationModule propagationModule) {
this.propagationModule = propagationModule;
}

@Override
public void onUrlDecode(
@Nonnull final String value, @Nullable final String encoding, @Nonnull final String result) {
taintStringIfTainted(result, value);
propagationModule.taintStringIfTainted(result, value);
}

@Override
Expand All @@ -22,22 +33,22 @@ public void onStringFromBytes(
@Nullable final String charset,
@Nonnull final String result) {
// create a new range shifted to the result string coordinates
taintStringIfRangeTainted(result, value, offset, length, false, NOT_MARKED);
propagationModule.taintStringIfRangeTainted(result, value, offset, length, false, NOT_MARKED);
}

@Override
public void onStringGetBytes(
@Nonnull final String value, @Nullable final String charset, @Nonnull final byte[] result) {
taintObjectIfTainted(result, value);
propagationModule.taintObjectIfTainted(result, value);
}

@Override
public void onBase64Encode(@Nullable byte[] value, @Nullable byte[] result) {
taintObjectIfTainted(result, value);
propagationModule.taintObjectIfTainted(result, value);
}

@Override
public void onBase64Decode(@Nullable byte[] value, @Nullable byte[] result) {
taintObjectIfTainted(result, value);
propagationModule.taintObjectIfTainted(result, value);
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
package com.datadog.iast.propagation

import com.datadog.iast.IastModuleImplTestBase
import com.datadog.iast.taint.TaintedObject
import com.datadog.iast.model.Range
import com.datadog.iast.model.Source
import com.datadog.iast.taint.Ranges
import com.datadog.iast.taint.TaintedObjects
import datadog.trace.api.iast.InstrumentationBridge
import datadog.trace.api.iast.VulnerabilityMarks
import datadog.trace.api.iast.propagation.CodecModule
import datadog.trace.bootstrap.instrumentation.api.AgentTracer
import groovy.transform.CompileDynamic

import java.nio.charset.StandardCharsets

import static com.datadog.iast.taint.TaintUtils.addFromTaintFormat

@CompileDynamic
abstract class BaseCodecModuleTest extends IastModuleImplTestBase {
class CodecModuleTest extends IastModuleImplTestBase {

protected CodecModule module

def setup() {
module = buildModule()
module = new CodecModuleImpl()
InstrumentationBridge.registerIastModule(module)
}

@Override
Expand Down Expand Up @@ -72,8 +78,14 @@ abstract class BaseCodecModuleTest extends IastModuleImplTestBase {
if (isTainted) {
assert to != null
assert to.get() == result
assert to.ranges.size() == 1

final sourceTainted = taintedObjects.get(parsed)
assertOnUrlDecode(value, encoding, sourceTainted, to)
final sourceRange = Ranges.highestPriorityRange(sourceTainted.ranges)
final range = to.ranges.first()
assert range.start == 0
assert range.length == result.length()
assert range.source == sourceRange.source
} else {
assert to == null
}
Expand Down Expand Up @@ -103,8 +115,14 @@ abstract class BaseCodecModuleTest extends IastModuleImplTestBase {
if (isTainted) {
assert to != null
assert to.get() == result
assert to.ranges.size() == 1

final sourceTainted = taintedObjects.get(parsed)
assertOnStringGetBytes(value, charset, sourceTainted, to)
final sourceRange = Ranges.highestPriorityRange(sourceTainted.ranges)
final range = to.ranges.first()
assert range.start == 0
assert range.length == Integer.MAX_VALUE // unbound for non char sequences
assert range.source == sourceRange.source
} else {
assert to == null
}
Expand Down Expand Up @@ -140,8 +158,14 @@ abstract class BaseCodecModuleTest extends IastModuleImplTestBase {
if (isTainted) {
assert to != null
assert to.get() == result
assert to.ranges.size() == 1

final sourceTainted = taintedObjects.get(parsed)
assertOnStringFromBytes(bytes, 0, bytes.length, charset, sourceTainted, to)
final sourceRange = Ranges.highestPriorityRange(sourceTainted.ranges)
final range = to.ranges.first()
assert range.start == 0
assert range.length == result.length()
assert range.source == sourceRange.source
} else {
assert to == null
}
Expand Down Expand Up @@ -177,8 +201,14 @@ abstract class BaseCodecModuleTest extends IastModuleImplTestBase {
if (isTainted) {
assert to != null
assert to.get() == result
assert to.ranges.length == 1

final sourceTainted = taintedObjects.get(parsed)
assertBase64Decode(parsedBytes, sourceTainted, to)
final sourceRange = Ranges.highestPriorityRange(sourceTainted.ranges)
final range = to.ranges.first()
assert range.start == 0
assert range.length == Integer.MAX_VALUE // unbound for non char sequences
assert range.source == sourceRange.source
} else {
assert to == null
}
Expand Down Expand Up @@ -214,8 +244,14 @@ abstract class BaseCodecModuleTest extends IastModuleImplTestBase {
if (isTainted) {
assert to != null
assert to.get() == result
assert to.ranges.length == 1

final sourceTainted = taintedObjects.get(parsed)
assertBase64Encode(parsedBytes, sourceTainted, to)
final sourceRange = Ranges.highestPriorityRange(sourceTainted.ranges)
final range = to.ranges.first()
assert range.start == 0
assert range.length == Integer.MAX_VALUE // unbound for non char sequences
assert range.source == sourceRange.source
} else {
assert to == null
}
Expand All @@ -232,15 +268,42 @@ abstract class BaseCodecModuleTest extends IastModuleImplTestBase {
'==>Hello<== World!' | _
}

protected abstract void assertOnUrlDecode(final String value, final String encoding, final TaintedObject source, final TaintedObject target)

protected abstract void assertOnStringFromBytes(final byte[] value, final int offset, final int length, final String encoding, final TaintedObject source, final TaintedObject target)
void 'test on string from bytes with multiple ranges'() {
given:
final charset = StandardCharsets.UTF_8
final string = "Hello World!"
final bytes = string.getBytes(charset) // 1 byte pe char
final TaintedObjects to = ctx.taintedObjects
final ranges = [
new Range(0, 5, new Source((byte) 0, 'name1', 'Hello'), VulnerabilityMarks.NOT_MARKED),
new Range(6, 6, new Source((byte) 1, 'name2', 'World!'), VulnerabilityMarks.NOT_MARKED)
]
to.taint(bytes, ranges as Range[])

protected abstract void assertOnStringGetBytes(final String value, final String encoding, final TaintedObject source, final TaintedObject target)
when:
final hello = string.substring(0, 5)
module.onStringFromBytes(bytes, 0, 5, charset.name(), hello)

protected abstract void assertBase64Decode(final byte[] value, final TaintedObject source, final TaintedObject target)
then:
final helloTainted = to.get(hello)
helloTainted.ranges.length == 1
helloTainted.ranges.first().with {
assert it.source.origin == (byte) 0
assert it.source.name == 'name1'
assert it.source.value == 'Hello'
}

protected abstract void assertBase64Encode(final byte[] value, final TaintedObject source, final TaintedObject target)
when:
final world = string.substring(6, 12)
module.onStringFromBytes(bytes, 6, 6, charset.name(), world)

protected abstract CodecModule buildModule()
then:
final worldTainted = to.get(world)
worldTainted.ranges.length == 1
worldTainted.ranges.first().with {
assert it.source.origin == (byte) 1
assert it.source.name == 'name2'
assert it.source.value == 'World!'
}
}
}

This file was deleted.

0 comments on commit e25c56d

Please sign in to comment.