Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

# @DataDog/asm-java (AppSec/IAST)
/buildSrc/call-site-instrumentation-plugin/ @DataDog/asm-java
/dd-java-agent/agent-aiguard/ @DataDog/asm-java
/dd-java-agent/agent-iast/ @DataDog/asm-java
/dd-java-agent/appsec/appsec-test-fixtures/ @DataDog/asm-java
/dd-java-agent/instrumentation/*iast* @DataDog/asm-java
Expand All @@ -58,6 +59,7 @@
/dd-smoke-tests/spring-security/ @DataDog/asm-java
/dd-java-agent/instrumentation/commons-fileupload/ @DataDog/asm-java
/dd-java-agent/instrumentation/spring/spring-security/ @DataDog/asm-java
/dd-trace-api/src/main/java/datadog/trace/api/aiguard/ @DataDog/asm-java
/dd-trace-api/src/main/java/datadog/trace/api/EventTracker.java @DataDog/asm-java
/internal-api/src/main/java/datadog/trace/api/gateway/ @DataDog/asm-java
**/appsec/ @DataDog/asm-java
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
package datadog.communication.serialization;

import datadog.communication.serialization.custom.aiguard.FunctionWriter;
import datadog.communication.serialization.custom.aiguard.MessageWriter;
import datadog.communication.serialization.custom.aiguard.ToolCallWriter;
import datadog.communication.serialization.custom.stacktrace.StackTraceEventFrameWriter;
import datadog.communication.serialization.custom.stacktrace.StackTraceEventWriter;
import datadog.trace.api.Config;
import datadog.trace.api.aiguard.AIGuard;
import datadog.trace.util.stacktrace.StackTraceEvent;
import datadog.trace.util.stacktrace.StackTraceFrame;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public final class Codec extends ClassValue<ValueWriter<?>> {

private static final Map<Class<?>, ValueWriter<?>> defaultConfig =
Stream.of(
new Object[][] {
{StackTraceEvent.class, new StackTraceEventWriter()},
{StackTraceFrame.class, new StackTraceEventFrameWriter()},
})
.collect(Collectors.toMap(data -> (Class<?>) data[0], data -> (ValueWriter<?>) data[1]));
public static final Codec INSTANCE;

public static final Codec INSTANCE = new Codec(defaultConfig);
static {
final Map<Class<?>, ValueWriter<?>> writers = new HashMap<>(1 << 3);
writers.put(StackTraceEvent.class, new StackTraceEventWriter());
writers.put(StackTraceFrame.class, new StackTraceEventFrameWriter());
if (Config.get().isAiGuardEnabled()) {
writers.put(AIGuard.Message.class, new MessageWriter());
writers.put(AIGuard.ToolCall.class, new ToolCallWriter());
writers.put(AIGuard.ToolCall.Function.class, new FunctionWriter());
}
INSTANCE = new Codec(writers);
}

private final Map<Class<?>, ValueWriter<?>> config;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package datadog.communication.serialization.custom.aiguard;

import datadog.communication.serialization.EncodingCache;
import datadog.communication.serialization.ValueWriter;
import datadog.communication.serialization.Writable;
import datadog.trace.api.aiguard.AIGuard;

public class FunctionWriter implements ValueWriter<AIGuard.ToolCall.Function> {

@Override
public void write(
final AIGuard.ToolCall.Function function,
final Writable writable,
final EncodingCache encodingCache) {
writable.startMap(2);
writable.writeString("name", encodingCache);
writable.writeString(function.getName(), encodingCache);
writable.writeString("arguments", encodingCache);
writable.writeString(function.getArguments(), encodingCache);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package datadog.communication.serialization.custom.aiguard;

import datadog.communication.serialization.EncodingCache;
import datadog.communication.serialization.ValueWriter;
import datadog.communication.serialization.Writable;
import datadog.trace.api.aiguard.AIGuard;
import datadog.trace.util.Strings;
import java.util.List;

public class MessageWriter implements ValueWriter<AIGuard.Message> {

@Override
public void write(
final AIGuard.Message value, final Writable writable, final EncodingCache encodingCache) {
final int[] size = {0};
final boolean hasRole = isNotBlank(value.getRole(), size);
final boolean hasContent = isNotBlank(value.getContent(), size);
final boolean hasToolCallId = isNotBlank(value.getToolCallId(), size);
final boolean hasToolCalls = isNotEmpty(value.getToolCalls(), size);
writable.startMap(size[0]);
writeString(hasRole, "role", value.getRole(), writable, encodingCache);
writeString(hasContent, "content", value.getContent(), writable, encodingCache);
writeString(hasToolCallId, "tool_call_id", value.getToolCallId(), writable, encodingCache);
writeToolCallArray(hasToolCalls, "tool_calls", value.getToolCalls(), writable, encodingCache);
}

private static void writeString(
final boolean present,
final String key,
final String value,
final Writable writable,
final EncodingCache encodingCache) {
if (present) {
writable.writeString(key, encodingCache);
writable.writeString(value, encodingCache);
}
}

private static void writeToolCallArray(
final boolean present,
final String key,
final List<AIGuard.ToolCall> values,
final Writable writable,
final EncodingCache encodingCache) {
if (present) {
writable.writeString(key, encodingCache);
writable.writeObject(values, encodingCache);
}
}

private static boolean isNotBlank(final String value, final int[] nonBlankCount) {
final boolean hasText = Strings.isNotBlank(value);
if (hasText) {
nonBlankCount[0]++;
}
return hasText;
}

private static boolean isNotEmpty(final List<?> value, final int[] nonEmptyCount) {
final boolean nonEmpty = value != null && !value.isEmpty();
if (nonEmpty) {
nonEmptyCount[0]++;
}
return nonEmpty;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package datadog.communication.serialization.custom.aiguard;

import datadog.communication.serialization.EncodingCache;
import datadog.communication.serialization.ValueWriter;
import datadog.communication.serialization.Writable;
import datadog.trace.api.aiguard.AIGuard;

public class ToolCallWriter implements ValueWriter<AIGuard.ToolCall> {

@Override
public void write(
final AIGuard.ToolCall value, final Writable writable, final EncodingCache encodingCache) {
writable.startMap(2);
writable.writeString("id", encodingCache);
writable.writeString(value.getId(), encodingCache);
writable.writeString("function", encodingCache);
writable.writeObject(value.getFunction(), encodingCache);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package datadog.communication.serialization.aiguard

import datadog.communication.serialization.EncodingCache
import datadog.communication.serialization.GrowableBuffer
import datadog.communication.serialization.msgpack.MsgPackWriter
import datadog.trace.api.aiguard.AIGuard
import datadog.trace.test.util.DDSpecification
import org.msgpack.core.MessagePack
import org.msgpack.value.Value

import java.nio.charset.StandardCharsets
import java.util.function.Function

class MessageWriterTest extends DDSpecification {

private EncodingCache encodingCache
private GrowableBuffer buffer
private MsgPackWriter writer

void setup() {
injectSysConfig('ai_guard.enabled', 'true')
final HashMap<CharSequence, byte[]> cache = new HashMap<>()
encodingCache = new EncodingCache() {
@Override
byte[] encode(CharSequence chars) {
cache.computeIfAbsent(chars, s -> s.toString().getBytes(StandardCharsets.UTF_8))
}
}
buffer = new GrowableBuffer(1024)
writer = new MsgPackWriter(buffer)
}

void 'test write message'() {
given:
final message = AIGuard.Message.message('user', 'What day is today?')

when:
writer.writeObject(message, encodingCache)

then:
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
final value = asStringValueMap(unpacker.unpackValue())
value.size() == 2
value.role == 'user'
value.content == 'What day is today?'
}
}

void 'test write tool call'() {
given:
final message =
AIGuard.Message.assistant(
AIGuard.ToolCall.toolCall('call_1', 'function_1', 'args_1'),
AIGuard.ToolCall.toolCall('call_2', 'function_2', 'args_2'))

when:
writer.writeObject(message, encodingCache)

then:
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
final value = asStringKeyMap(unpacker.unpackValue())
value.size() == 2
asString(value.role) == 'assistant'

final toolCalls = value.get('tool_calls').asArrayValue().list()
toolCalls.size() == 2

final firstCall = asStringKeyMap(toolCalls[0])
asString(firstCall.id) == 'call_1'
final firstFunction = asStringValueMap(firstCall.function)
firstFunction.name == 'function_1'
firstFunction.arguments == 'args_1'

final secondCall = asStringKeyMap(toolCalls[1])
asString(secondCall.id) == 'call_2'
final secondFunction = asStringValueMap(secondCall.function)
secondFunction.name == 'function_2'
secondFunction.arguments == 'args_2'
}
}

void 'test write tool output'() throws IOException {
given:
final message = AIGuard.Message.tool('call_1', 'output')

when:
writer.writeObject(message, encodingCache)

then:
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
final value = asStringValueMap(unpacker.unpackValue())
value.size() == 3
value.role == 'tool'
value.tool_call_id == 'call_1'
value.content == 'output'
}
}

private static <K, V> Map<K, V> mapValue(
final Value values,
final Function<Value, K> keyMapper,
final Function<Value, V> valueMapper) {
return values.asMapValue().entrySet().collectEntries {
[(keyMapper.apply(it.key)): valueMapper.apply(it.value)]
}
}

private static Map<String, Value> asStringKeyMap(final Value values) {
return mapValue(values, MessageWriterTest::asString, Function.identity())
}

private static Map<String, String> asStringValueMap(final Value values) {
return mapValue(values, MessageWriterTest::asString, MessageWriterTest::asString)
}

private static String asString(final Value value) {
return value.asStringValue().asString()
}
}
36 changes: 36 additions & 0 deletions dd-java-agent/agent-aiguard/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar

plugins {
id 'com.gradleup.shadow'
}

apply from: "$rootDir/gradle/java.gradle"
apply from: "$rootDir/gradle/version.gradle"

java {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}

dependencies {
api libs.slf4j
implementation libs.moshi
implementation libs.okhttp

api project(':dd-trace-api')
implementation project(':internal-api')
implementation project(':communication')

testImplementation project(':utils:test-utils')
testImplementation('org.skyscreamer:jsonassert:1.5.3')
testImplementation('com.fasterxml.jackson.core:jackson-databind:2.20.0')
}

tasks.named("shadowJar", ShadowJar) {
dependencies deps.excludeShared
}

tasks.named("jar", Jar) {
archiveClassifier = 'unbundled'
}

Loading