Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ void taskRegisteredWithCorrectTypeAndGroup() {
Project project = ProjectBuilder.builder().build();
project.getPluginManager().apply("com.agenteval.evaluate");

var task = project.getTasks().findByName("agentEvaluate");
assertThat(task).isNotNull();
var task = project.getTasks().getByName("agentEvaluate");
assertThat(task).isInstanceOf(EvaluateTask.class);
assertThat(task.getGroup()).isEqualTo("verification");
}
Expand All @@ -43,8 +42,7 @@ void extensionDefaultValues() {
project.getPluginManager().apply("com.agenteval.evaluate");

AgentEvalExtension ext = project.getExtensions()
.findByType(AgentEvalExtension.class);
assertThat(ext).isNotNull();
.getByType(AgentEvalExtension.class);

assertThat(ext.getConfigFile().get()).isEqualTo("agenteval.yaml");
assertThat(ext.getReportFormats().get()).isEqualTo("console,json");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ static void checkEnvironment() {
try {
Project project = ProjectBuilder.builder().build();
project.getPluginManager().apply("com.agenteval.evaluate");
EvaluateTask task = (EvaluateTask) project.getTasks().findByName("agentEvaluate");
EvaluateTask task = (EvaluateTask) project.getTasks().getByName("agentEvaluate");
gradleTaskCreationSupported = task != null;
} catch (Exception e) {
gradleTaskCreationSupported = false;
Expand All @@ -34,8 +34,7 @@ void taskDefaultPropertyValues() {
Project project = ProjectBuilder.builder().build();
project.getPluginManager().apply("com.agenteval.evaluate");

EvaluateTask task = (EvaluateTask) project.getTasks().findByName("agentEvaluate");
assertThat(task).isNotNull();
EvaluateTask task = (EvaluateTask) project.getTasks().getByName("agentEvaluate");

assertThat(task.getConfigFile().get()).isEqualTo("agenteval.yaml");
assertThat(task.getReportFormats().get()).isEqualTo("console,json");
Expand All @@ -53,8 +52,7 @@ void datasetPathRequiredValidation() {
Project project = ProjectBuilder.builder().build();
project.getPluginManager().apply("com.agenteval.evaluate");

EvaluateTask task = (EvaluateTask) project.getTasks().findByName("agentEvaluate");
assertThat(task).isNotNull();
EvaluateTask task = (EvaluateTask) project.getTasks().getByName("agentEvaluate");

assertThat(task.getDatasetPath().isPresent()).isFalse();
}
Expand All @@ -68,14 +66,12 @@ void extensionOverridesWireToTask() {
project.getPluginManager().apply("com.agenteval.evaluate");

AgentEvalExtension ext = project.getExtensions()
.findByType(AgentEvalExtension.class);
assertThat(ext).isNotNull();
.getByType(AgentEvalExtension.class);

ext.getMetrics().set("Faithfulness,Correctness");
ext.getThreshold().set(0.8);

EvaluateTask task = (EvaluateTask) project.getTasks().findByName("agentEvaluate");
assertThat(task).isNotNull();
EvaluateTask task = (EvaluateTask) project.getTasks().getByName("agentEvaluate");
assertThat(task.getMetrics().get()).isEqualTo("Faithfulness,Correctness");
assertThat(task.getThreshold().get()).isEqualTo(0.8);
}
Expand Down
88 changes: 83 additions & 5 deletions agenteval-judge/src/main/java/com/agenteval/judge/JudgeModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import com.agenteval.judge.config.JudgeConfig;
import com.agenteval.judge.multi.MultiModelJudge;
import com.agenteval.judge.provider.AnthropicJudgeModel;
import com.agenteval.judge.provider.AzureOpenAiJudgeModel;
import com.agenteval.judge.provider.BedrockJudgeModel;
import com.agenteval.judge.provider.CustomHttpJudgeModel;
import com.agenteval.judge.provider.GoogleJudgeModel;
import com.agenteval.judge.provider.OllamaJudgeModel;
import com.agenteval.judge.provider.OpenAiJudgeModel;

Expand All @@ -15,20 +19,24 @@
* <pre>{@code
* var judge = JudgeModels.openai("gpt-4o");
* var judge = JudgeModels.anthropic("claude-sonnet-4-20250514");
* var judge = JudgeModels.google("gemini-1.5-pro");
* var judge = JudgeModels.azure(JudgeConfig.builder()
* .apiKey("...").model("my-deployment").baseUrl("https://myresource.openai.azure.com").build());
* var judge = JudgeModels.bedrock("anthropic.claude-3-sonnet-20240229-v1:0");
* var judge = JudgeModels.custom(JudgeConfig.builder()
* .model("my-model").baseUrl("http://localhost:8000").build());
* var judge = JudgeModels.ollama("llama3");
* var judge = JudgeModels.openai(JudgeConfig.builder()
* .apiKey("sk-...")
* .model("gpt-4o")
* .baseUrl("https://api.openai.com")
* .build());
* }</pre>
*/
public final class JudgeModels {

// Names of the environment variables that hold provider API keys.
private static final String OPENAI_API_KEY_ENV = "OPENAI_API_KEY";
private static final String ANTHROPIC_API_KEY_ENV = "ANTHROPIC_API_KEY";
private static final String GOOGLE_API_KEY_ENV = "GOOGLE_API_KEY";
// Default provider endpoints. Azure has no default here: its base URL is resource-specific.
private static final String OPENAI_BASE_URL = "https://api.openai.com";
private static final String ANTHROPIC_BASE_URL = "https://api.anthropic.com";
private static final String GOOGLE_BASE_URL = "https://generativelanguage.googleapis.com";
// Pinned to us-east-1; pass a JudgeConfig with another base URL to bedrock(JudgeConfig) for other regions.
private static final String BEDROCK_BASE_URL = "https://bedrock-runtime.us-east-1.amazonaws.com";
private static final String OLLAMA_BASE_URL = "http://localhost:11434";

// Static factory holder; not meant to be instantiated.
private JudgeModels() {}
Expand Down Expand Up @@ -71,6 +79,76 @@ public static JudgeModel anthropic(JudgeConfig config) {
return new AnthropicJudgeModel(config);
}

/**
 * Builds a Google Gemini judge for the given model ID against the default
 * Gemini endpoint.
 *
 * <p>The API key is obtained via {@code resolveApiKey} for the
 * {@code GOOGLE_API_KEY} environment variable; the resulting configuration is
 * handed to {@link #google(JudgeConfig)}.</p>
 *
 * @param model the Gemini model ID
 * @return the judge model created by {@link #google(JudgeConfig)}
 */
public static JudgeModel google(String model) {
    JudgeConfig geminiDefaults = JudgeConfig.builder()
            .apiKey(resolveApiKey(GOOGLE_API_KEY_ENV, "Google"))
            .model(model)
            .baseUrl(GOOGLE_BASE_URL)
            .build();
    return google(geminiDefaults);
}

/**
* Creates a Google Gemini judge model with full configuration.
*
* @param config the judge configuration, passed unchanged to {@link GoogleJudgeModel}
* @return a new {@link GoogleJudgeModel}
*/
public static JudgeModel google(JudgeConfig config) {
return new GoogleJudgeModel(config);
}

/**
* Creates an Azure OpenAI judge model from an explicit configuration.
*
* <p>This factory does not resolve an API key from the environment: the key
* must be set on {@code config}, and {@link AzureOpenAiJudgeModel} rejects a
* null or blank key with a {@code JudgeException}. The config's model is used
* as the Azure deployment name, and the base URL must point to your Azure
* OpenAI resource (e.g., {@code https://myresource.openai.azure.com}).
* The provider's default API version ({@code 2024-02-01}) is used.</p>
*
* @param config configuration carrying the API key, deployment name and base URL
* @return a new {@link AzureOpenAiJudgeModel}
*/
public static JudgeModel azure(JudgeConfig config) {
return new AzureOpenAiJudgeModel(config);
}

/**
* Creates an Azure OpenAI judge model with a specific API version.
*
* @param config configuration carrying the API key, deployment name and base URL
* @param apiVersion value sent as the {@code api-version} query parameter;
*     {@code null} falls back to the provider default ({@code 2024-02-01})
* @return a new {@link AzureOpenAiJudgeModel}
*/
public static JudgeModel azure(JudgeConfig config, String apiVersion) {
return new AzureOpenAiJudgeModel(config, apiVersion);
}

/**
 * Builds an Amazon Bedrock judge for the given model ID using the default
 * Bedrock runtime endpoint (us-east-1).
 *
 * <p>No API key is placed on the configuration. AWS credentials are expected
 * to come from the {@code AWS_ACCESS_KEY_ID} and {@code AWS_SECRET_ACCESS_KEY}
 * environment variables (resolved by the provider, not by this method).</p>
 *
 * @param model the Bedrock model ID (e.g., {@code anthropic.claude-3-sonnet-20240229-v1:0})
 * @return the judge model created by {@link #bedrock(JudgeConfig)}
 */
public static JudgeModel bedrock(String model) {
    JudgeConfig bedrockDefaults = JudgeConfig.builder()
            .model(model)
            .baseUrl(BEDROCK_BASE_URL)
            .build();
    return bedrock(bedrockDefaults);
}

/**
* Creates an Amazon Bedrock judge model with full configuration.
* The base URL determines the AWS region
* (e.g., {@code https://bedrock-runtime.eu-west-1.amazonaws.com}).
*
* @param config the judge configuration, passed unchanged to {@link BedrockJudgeModel}
* @return a new {@link BedrockJudgeModel}
*/
public static JudgeModel bedrock(JudgeConfig config) {
return new BedrockJudgeModel(config);
}

/**
* Creates a custom HTTP judge model for any OpenAI-compatible endpoint.
* No API key required — if one is set in the config, it will be sent
* as a {@code Bearer} token.
*
* <p>Use this for vLLM, LiteLLM, LocalAI, or any OpenAI-compatible server.</p>
*
* @param config the judge configuration (model and base URL of the target server,
*     optionally an API key), passed unchanged to {@link CustomHttpJudgeModel}
* @return a new {@link CustomHttpJudgeModel}
*/
public static JudgeModel custom(JudgeConfig config) {
return new CustomHttpJudgeModel(config);
}

/**
* Creates an Ollama judge model using the given model ID.
* Defaults to {@code localhost:11434}. No API key required.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package com.agenteval.judge.provider;

import com.agenteval.core.model.TokenUsage;
import com.agenteval.judge.JudgeException;
import com.agenteval.judge.config.JudgeConfig;
import com.agenteval.judge.http.HttpJudgeClient;
import com.agenteval.judge.http.HttpJudgeRequest;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.Map;

/**
 * Azure OpenAI judge model provider.
 *
 * <p>Sends requests to Azure's OpenAI-compatible endpoint at
 * {@code {baseUrl}/openai/deployments/{model}/chat/completions?api-version=...}
 * with {@code api-key} header authentication.</p>
 *
 * <p>The {@link JudgeConfig#getModel()} value is used as the deployment name.
 * The API version defaults to {@code 2024-02-01} and can be customized through
 * the {@link #AzureOpenAiJudgeModel(JudgeConfig, String)} constructor.</p>
 */
public final class AzureOpenAiJudgeModel extends AbstractHttpJudgeModel {

    private static final ObjectMapper MAPPER = new ObjectMapper();
    private static final String DEFAULT_API_VERSION = "2024-02-01";
    private static final String COMPLETIONS_PATH =
            "/openai/deployments/%s/chat/completions?api-version=%s";
    private static final String SYSTEM_PROMPT =
            "You are an evaluation judge. Respond ONLY with a JSON object "
            + "containing \"score\" (a number between 0.0 and 1.0) "
            + "and \"reason\" (a brief explanation).";

    private final String apiVersion;

    /**
     * Creates a provider that uses the default API version.
     *
     * @param config configuration carrying the API key, deployment name and base URL
     * @throws JudgeException if the configured API key is null or blank
     */
    public AzureOpenAiJudgeModel(JudgeConfig config) {
        this(config, DEFAULT_API_VERSION);
    }

    /**
     * Creates a provider that targets a specific Azure OpenAI API version.
     *
     * @param config configuration carrying the API key, deployment name and base URL
     * @param apiVersion value for the {@code api-version} query parameter;
     *     {@code null} selects the default version
     * @throws JudgeException if the configured API key is null or blank
     */
    public AzureOpenAiJudgeModel(JudgeConfig config, String apiVersion) {
        super(config);
        if (config.getApiKey() == null || config.getApiKey().isBlank()) {
            throw new JudgeException("Azure OpenAI requires a non-null API key");
        }
        this.apiVersion = orDefaultVersion(apiVersion);
    }

    /**
     * Package-private constructor that accepts a caller-supplied HTTP client.
     * Unlike the public constructors it performs no API-key validation.
     */
    AzureOpenAiJudgeModel(JudgeConfig config, String apiVersion, HttpJudgeClient client) {
        super(config, client);
        this.apiVersion = orDefaultVersion(apiVersion);
    }

    /** Returns the API version sent as the {@code api-version} query parameter. */
    String getApiVersion() {
        return apiVersion;
    }

    /**
     * Builds the chat-completions request: a fixed system prompt plus the given
     * prompt as the user message, JSON-object response format, and the
     * configured temperature.
     *
     * @param prompt the evaluation prompt sent as the user message
     * @return the request with the deployment URL, the {@code api-key} header and the JSON body
     * @throws JudgeException if the request cannot be assembled or serialized
     */
    @Override
    protected HttpJudgeRequest buildRequest(String prompt) {
        try {
            var body = MAPPER.createObjectNode();
            body.put("temperature", config.getTemperature());

            var responseFormat = MAPPER.createObjectNode();
            responseFormat.put("type", "json_object");
            body.set("response_format", responseFormat);

            var messages = body.putArray("messages");

            var systemMsg = messages.addObject();
            systemMsg.put("role", "system");
            systemMsg.put("content", SYSTEM_PROMPT);

            var userMsg = messages.addObject();
            userMsg.put("role", "user");
            userMsg.put("content", prompt);

            // Endpoints are often configured with a trailing slash; drop it so the
            // URL does not end up as "...azure.com//openai/deployments/...".
            String url = stripTrailingSlashes(config.getBaseUrl())
                    + String.format(COMPLETIONS_PATH, config.getModel(), apiVersion);
            return new HttpJudgeRequest(
                    url,
                    Map.of("api-key", config.getApiKey()),
                    MAPPER.writeValueAsString(body));
        } catch (Exception e) {
            throw new JudgeException("Failed to build Azure OpenAI request", e);
        }
    }

    /**
     * Extracts the first choice's message content from a chat-completions response.
     *
     * @param responseBody the raw JSON response body
     * @return the message content, or an empty string when the content is absent or null
     * @throws JudgeException if the response has no non-empty {@code choices} array
     */
    @Override
    protected String extractContent(String responseBody) {
        JsonNode choices = parseJson(responseBody).path("choices");
        // isArray() guards against a malformed non-array "choices", for which get(0) returns null.
        if (!choices.isArray() || choices.isEmpty()) {
            throw new JudgeException("No choices in Azure OpenAI response");
        }
        return choices.get(0).path("message").path("content").asText("");
    }

    /**
     * Extracts token counts from the response's {@code usage} object.
     *
     * @param responseBody the raw JSON response body
     * @return the token usage (missing counters read as 0), or {@code null} when
     *     the response carries no {@code usage} object
     */
    @Override
    protected TokenUsage extractTokenUsage(String responseBody) {
        JsonNode usage = parseJson(responseBody).path("usage");
        // Treat an explicit JSON null (or any non-object) like a missing field
        // instead of reporting a misleading 0/0/0 usage.
        if (!usage.isObject()) {
            return null;
        }
        return new TokenUsage(
                usage.path("prompt_tokens").asInt(0),
                usage.path("completion_tokens").asInt(0),
                usage.path("total_tokens").asInt(0));
    }

    /** Returns {@code apiVersion}, or the default version when it is {@code null}. */
    private static String orDefaultVersion(String apiVersion) {
        return apiVersion != null ? apiVersion : DEFAULT_API_VERSION;
    }

    /** Removes trailing {@code '/'} characters; a {@code null} input is returned unchanged. */
    private static String stripTrailingSlashes(String baseUrl) {
        if (baseUrl == null) {
            return null;
        }
        int end = baseUrl.length();
        while (end > 0 && baseUrl.charAt(end - 1) == '/') {
            end--;
        }
        return baseUrl.substring(0, end);
    }
}
Loading