Skip to content

Commit

Permalink
create a new method for ssrfModule and inprove tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jandro996 committed Nov 14, 2023
1 parent 30cda09 commit c5a4757
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
package com.datadog.iast.sink;

import static datadog.trace.api.iast.VulnerabilityMarks.NOT_MARKED;

import com.datadog.iast.Dependencies;
import com.datadog.iast.IastRequestContext;
import com.datadog.iast.model.Evidence;
import com.datadog.iast.model.Range;
import com.datadog.iast.model.Source;
import com.datadog.iast.model.VulnerabilityType;
import com.datadog.iast.overhead.Operations;
import com.datadog.iast.taint.Ranges;
import com.datadog.iast.taint.TaintedObject;
import datadog.trace.api.iast.sink.SsrfModule;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
Expand All @@ -26,4 +34,49 @@ public void onURLConnection(@Nullable final Object url) {
}
checkInjection(span, ctx, VulnerabilityType.SSRF, url);
}

/*
* if the host or the uri are tainted, we report the url as tainted as well
* A new range is created covering all the value string in order to simplify the algorithm
*/
@Override
public void onURLConnection(@Nullable String value, @Nullable Object host, @Nullable Object uri) {
if (value == null) {
return;
}
final AgentSpan span = AgentTracer.activeSpan();
final IastRequestContext ctx = IastRequestContext.get(span);
if (ctx == null) {
return;
}
TaintedObject taintedObject = getTaintedObject(ctx, host, uri);
if (taintedObject == null) {
return;
}
Range[] ranges =
Ranges.getNotMarkedRanges(taintedObject.getRanges(), VulnerabilityType.SSRF.mark());
if (ranges == null || ranges.length == 0) {
return;
}
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
return;
}
Source source = Ranges.highestPriorityRange(ranges).getSource();
final Evidence result =
new Evidence(value, new Range[] {new Range(0, value.length(), source, NOT_MARKED)});
report(span, VulnerabilityType.SSRF, result);
}

@Nullable
private TaintedObject getTaintedObject(
final IastRequestContext ctx, @Nullable final Object host, @Nullable final Object uri) {
TaintedObject taintedObject = null;
if (uri != null) {
taintedObject = ctx.getTaintedObjects().get(uri);
}
if (taintedObject == null && host != null) {
taintedObject = ctx.getTaintedObjects().get(host);
}
return taintedObject;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,42 @@ class SsrfModuleTest extends IastModuleImplTestBase {
0 * reporter.report(_, _)
}

void 'test SSRF detection for host and uri'() {
when:
module.onURLConnection(value, host, uri)

then: 'report is not called if no active span'
tracer.activeSpan() >> null
0 * reporter.report(_, _)

when:
module.onURLConnection(value, host, uri)

then: 'report is not called if host or uri are not tainted'
tracer.activeSpan() >> span
0 * reporter.report(_, _)

when:
taint(host!=null ? host : uri)
module.onURLConnection(value, host, uri)

then: 'report is called when the host or uri are tainted'
tracer.activeSpan() >> span
1 * reporter.report(span, {
Vulnerability vul -> vul.type == VulnerabilityType.SSRF
&& vul.evidence.value == value
&& vul.evidence.ranges.length == 1
&& vul.evidence.ranges[0].start == 0
&& vul.evidence.ranges[0].length == value.length()
})


where:
value | host | uri
'http://test.com' | new Object() | new URI('http://test.com/tested')
'http://test.com' | null | new URI('http://test.com/tested')
}

private void taint(final Object value) {
ctx.getTaintedObjects().taint(value, Ranges.forObject(new Source(SourceTypes.REQUEST_PARAMETER_VALUE, 'name', value.toString())))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datadog.trace.api.iast.sink.SsrfModule;
import datadog.trace.bootstrap.CallDepthThreadLocalMap;
import datadog.trace.bootstrap.instrumentation.api.URIUtils;
import java.net.URI;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.client.methods.HttpUriRequest;
Expand All @@ -25,10 +26,11 @@ public static void doMethodEnter(HttpHost httpHost, HttpRequest request) {
if (callDepth > 0) {
return;
}
final SsrfModule module = InstrumentationBridge.SSRF;
if (module != null) {
module.onURLConnection(
URIUtils.safeConcat(httpHost.toURI(), request.getRequestLine().getUri()).toString());
final SsrfModule ssrfModule = InstrumentationBridge.SSRF;
if (ssrfModule != null) {
URI concatedUri = URIUtils.safeConcat(httpHost.toURI(), request.getRequestLine().getUri());
ssrfModule.onURLConnection(
concatedUri.toString(), httpHost.toURI(), request.getRequestLine().getUri());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,32 @@ public String apacheSsrf(
@RequestParam("url") final String url,
@RequestParam("client") final String client,
@RequestParam("method") final String method,
@RequestParam("requestType") final String requestType) {
@RequestParam("requestType") final String requestType,
@RequestParam("scheme") final String scheme) {
try {
HttpClient httpClient = getHttpClient(Client.valueOf(client));
execute(httpClient, url, ExecuteMethod.valueOf(method), requestType);
execute(httpClient, url, ExecuteMethod.valueOf(method), requestType, scheme);
} catch (Exception e) {
}
return "OK";
}

private void execute(
final HttpClient client, final String url, ExecuteMethod executeMethod, String requestType)
final HttpClient client,
final String url,
ExecuteMethod executeMethod,
String requestType,
String scheme)
throws IOException {
HttpUriRequest httpUriRequest = new HttpGet(url);
boolean isUriRequest = requestType.equals(Request.HttpUriRequest.name());
HttpRequest httpRequest = isUriRequest ? httpUriRequest : new BasicHttpRequest("GET", url);
HttpHost host =
new HttpHost(
httpUriRequest.getURI().getHost(),
httpUriRequest.getURI().getPort(),
httpUriRequest.getURI().getScheme());
new HttpHost(httpUriRequest.getURI().getHost(), httpUriRequest.getURI().getPort(), scheme);
HttpRequest httpRequest =
isUriRequest
? httpUriRequest
: new BasicHttpRequest(
"GET", url.startsWith(scheme) ? url : url.substring(host.toURI().length()));
switch (executeMethod) {
case REQUEST:
client.execute(httpUriRequest);
Expand Down Expand Up @@ -117,14 +123,14 @@ public enum Request {
}

private enum ExecuteMethod {
REQUEST(55),
REQUEST_CONTEXT(58),
HOST_REQUEST(61),
REQUEST_HANDLER(64),
REQUEST_HANDLER_CONTEXT(67),
HOST_REQUEST_HANDLER(70),
HOST_REQUEST_HANDLER_CONTEXT(73),
HOST_REQUEST_CONTEXT(76);
REQUEST(56),
REQUEST_CONTEXT(59),
HOST_REQUEST(62),
REQUEST_HANDLER(65),
REQUEST_HANDLER_CONTEXT(68),
HOST_REQUEST_HANDLER(71),
HOST_REQUEST_HANDLER_CONTEXT(74),
HOST_REQUEST_CONTEXT(77);

final Integer expectedLine;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,19 @@ class SmokeTest extends AbstractIastServerSmokeTest {

void 'ssrf is present'() {
setup:
final expected = 'https://dd.datad0g.com/test'
def expected = 'http://dd.datad0g.com/test'+'/'+suite.executedMethod
final url = "http://localhost:${httpPort}/ssrf/execute"
final body = new FormBody.Builder()
.add('url', expected)
.add('client', suite.clientImplementation)
.add('method', suite.executedMethod)
.add('requestType', suite.requestType)
.add('scheme', suite.scheme)
.build()
final request = new Request.Builder().url(url).post(body).build()
if(suite.scheme == 'https') {
expected = expected.replace('http', 'https')
}

when:
def response = client.newCall(request).execute()
Expand All @@ -67,21 +71,23 @@ class SmokeTest extends AbstractIastServerSmokeTest {
for (SsrfController.Client client : SsrfController.Client.values()) {
for (SsrfController.ExecuteMethod method : SsrfController.ExecuteMethod.values()) {
if (method.name().startsWith('HOST')) {
result.add(createTestSuite(client, method, SsrfController.Request.HttpRequest))
result.add(createTestSuite(client, method, SsrfController.Request.HttpRequest, 'http'))
result.add(createTestSuite(client, method, SsrfController.Request.HttpRequest, 'https'))
}
result.add(createTestSuite(client, method, SsrfController.Request.HttpUriRequest))
result.add(createTestSuite(client, method, SsrfController.Request.HttpUriRequest, 'http'))
}
}
return result as Iterable<TestSuite>
}

private TestSuite createTestSuite(client, method, request) {
private TestSuite createTestSuite(client, method, request, scheme) {
return new TestSuite(
description: "ssrf is present for ${client} client and ${method} method with ${request}",
description: "ssrf is present for ${client} client and ${method} method with ${request} and ${scheme} scheme",
executedMethod: method.name(),
clientImplementation: client.name(),
expectedLine: method.expectedLine,
requestType: request.name()
requestType: request.name(),
scheme: scheme
)
}

Expand All @@ -91,6 +97,7 @@ class SmokeTest extends AbstractIastServerSmokeTest {
String clientImplementation
Integer expectedLine
String requestType
String scheme

@Override
String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@
public interface SsrfModule extends IastModule {

void onURLConnection(@Nullable Object url);

void onURLConnection(@Nullable String url, @Nullable Object host, @Nullable Object uri);
}

0 comments on commit c5a4757

Please sign in to comment.