Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.CONTENT;
import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.MESSAGES;
import static datadog.trace.util.Strings.isBlank;
import static java.util.Collections.singletonMap;

import com.squareup.moshi.JsonAdapter;
import com.squareup.moshi.JsonReader;
Expand Down Expand Up @@ -69,7 +68,8 @@ public BadConfigurationException(final String message) {
static final String REASON_TAG = "ai_guard.reason";
static final String BLOCKED_TAG = "ai_guard.blocked";
static final String META_STRUCT_TAG = "ai_guard";
static final String META_STRUCT_KEY = "messages";
static final String META_STRUCT_MESSAGES = "messages";
static final String META_STRUCT_CATEGORIES = "attack_categories";

public static void install() {
final Config config = Config.get();
Expand Down Expand Up @@ -208,8 +208,8 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
} else {
span.setTag(TARGET_TAG, "prompt");
}
final Map<String, Object> metaStruct =
singletonMap(META_STRUCT_KEY, messagesForMetaStruct(messages));
final Map<String, Object> metaStruct = new HashMap<>(2);
metaStruct.put(META_STRUCT_MESSAGES, messagesForMetaStruct(messages));
span.setMetaStruct(META_STRUCT_TAG, metaStruct);
final Request.Builder request =
new Request.Builder()
Expand All @@ -224,14 +224,21 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
}
final Action action = Action.valueOf(actionStr);
final String reason = (String) result.get("reason");
@SuppressWarnings("unchecked")
final List<String> tags = (List<String>) result.get("tags");
span.setTag(ACTION_TAG, action);
span.setTag(REASON_TAG, reason);
if (reason != null) {
span.setTag(REASON_TAG, reason);
}
if (tags != null && !tags.isEmpty()) {
metaStruct.put(META_STRUCT_CATEGORIES, tags);
}
final boolean shouldBlock =
isBlockingEnabled(options, result.get("is_blocking_enabled")) && action != Action.ALLOW;
WafMetricCollector.get().aiGuardRequest(action, shouldBlock);
if (shouldBlock) {
span.setTag(BLOCKED_TAG, true);
throw new AIGuardAbortError(action, reason);
throw new AIGuardAbortError(action, reason, tags);
}
return new Evaluation(action, reason);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,14 @@ class AIGuardInternalTests extends DDSpecification {
Request request = null
Throwable error = null
AIGuard.Evaluation eval = null
Map<String, Object> receivedMeta = null
final throwAbortError = suite.blocking && suite.action != ALLOW
final call = Mock(Call) {
execute() >> {
return mockResponse(
request,
200,
[data: [attributes: [action: suite.action, reason: suite.reason, is_blocking_enabled: suite.blocking]]]
[data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], is_blocking_enabled: suite.blocking]]]
)
}
}
Expand All @@ -189,11 +190,18 @@ class AIGuardInternalTests extends DDSpecification {
}
1 * span.setTag(AIGuardInternal.ACTION_TAG, suite.action)
1 * span.setTag(AIGuardInternal.REASON_TAG, suite.reason)
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, [messages: suite.messages])
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> {
receivedMeta = it[1] as Map<String, Object>
return span
}
if (throwAbortError) {
1 * span.addThrowable(_ as AIGuard.AIGuardAbortError)
}

receivedMeta.messages == suite.messages
if (suite.tags) {
receivedMeta.attack_categories == suite.tags
}
assertRequest(request, suite.messages)
if (throwAbortError) {
error instanceof AIGuard.AIGuardAbortError
Expand Down Expand Up @@ -444,6 +452,14 @@ class AIGuardInternalTests extends DDSpecification {
0 * span.setTag(AIGuardInternal.TOOL_TAG, _)
}

void 'map requires even number of params'() {
when:
AIGuardInternal.mapOf('1', '2', '3')

then:
thrown(IllegalArgumentException)
}

private static assertTelemetry(final String metric, final String...tags) {
final metrics = WafMetricCollector.get().with {
prepareMetrics()
Expand Down Expand Up @@ -497,22 +513,28 @@ class AIGuardInternalTests extends DDSpecification {
private static class TestSuite {
private final AIGuard.Action action
private final String reason
private final List<String> tags
private final boolean blocking
private final String description
private final String target
private final List<AIGuard.Message> messages

TestSuite(AIGuard.Action action, String reason, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
TestSuite(AIGuard.Action action, String reason, List<String> tags, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
this.action = action
this.reason = reason
this.tags = tags
this.blocking = blocking
this.description = description
this.target = target
this.messages = messages
}

static List<TestSuite> build() {
def actionValues = [[ALLOW, 'Go ahead'], [DENY, 'Nope'], [ABORT, 'Kill it with fire']]
def actionValues = [
[ALLOW, 'Go ahead', []],
[DENY, 'Nope', ['deny_everything', 'test_deny']],
[ABORT, 'Kill it with fire', ['alarm_tag', 'abort_everything']]
]
def blockingValues = [true, false]
def suiteValues = [
['tool call', 'tool', TOOL_CALL],
Expand All @@ -521,7 +543,7 @@ class AIGuardInternalTests extends DDSpecification {
]
return combinations([actionValues, blockingValues, suiteValues] as Iterable)
.collect { action, blocking, suite ->
new TestSuite(action[0], action[1], blocking, suite[0], suite[1], suite[2])
new TestSuite(action[0], action[1], action[2], blocking, suite[0], suite[1], suite[2])
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@ public static Evaluation evaluate(final List<Message> messages, final Options op
public static class AIGuardAbortError extends RuntimeException {
private final Action action;
private final String reason;
private final List<String> tags;

public AIGuardAbortError(final Action action, final String reason) {
public AIGuardAbortError(final Action action, final String reason, final List<String> tags) {
super(reason);
this.action = action;
this.reason = reason;
this.tags = tags;
}

public Action getAction() {
Expand All @@ -77,6 +79,10 @@ public Action getAction() {
public String getReason() {
return reason;
}

public List<String> getTags() {
return tags;
}
}

/**
Expand Down