12 changes: 8 additions & 4 deletions lib/active_agent/generation_provider/anthropic_provider.rb
@@ -36,9 +36,12 @@ def generate(prompt)
end

def chat_prompt(parameters: prompt_parameters)
parameters[:stream] = provider_stream if prompt.options[:stream] || config["stream"]
if prompt.options[:stream] || config["stream"]
parameters[:stream] = provider_stream
@streaming_request_params = parameters
end

chat_response(@client.messages(parameters: parameters))
chat_response(@client.messages(parameters: parameters), parameters)
end

protected
@@ -120,7 +123,7 @@ def convert_role(role)
end
end

def chat_response(response)
def chat_response(response, request_params = nil)
return @response if prompt.options[:stream]

content = response["content"].first["text"]
@@ -137,7 +140,8 @@ def chat_response(response)
@response = ActiveAgent::GenerationProvider::Response.new(
prompt: prompt,
message: message,
raw_response: response
raw_response: response,
raw_request: request_params
)
end

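For context, a condensed sketch of what the new plumbing yields, modeled on the Anthropic test fixtures added at the end of this PR. The parameter and payload values are illustrative, and the provider is assumed to already hold a prepared @prompt, as the tests set up; this is not code from the PR itself.

# Not part of the PR; mirrors AnthropicRawRequestTest below.
params = {
  model: "claude-3-opus-20240229",
  messages: [ { role: "user", content: "Hello, Claude!" } ],
  max_tokens: 1024
}
reply = {
  "content" => [ { "type" => "text", "text" => "Hello!" } ],
  "stop_reason" => "end_turn",
  "usage" => { "input_tokens" => 10, "output_tokens" => 12 }
}

response = provider.send(:chat_response, reply, params)
response.raw_request   # => params, with String values passed through ActiveAgent.sanitize_credentials
response.raw_response  # => reply
response.usage         # => { "input_tokens" => 10, "output_tokens" => 12 }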
46 changes: 35 additions & 11 deletions lib/active_agent/generation_provider/open_ai_provider.rb
@@ -69,7 +69,12 @@ def process_stream_chunk(chunk, message, agent_stream)
elsif chunk.dig("choices", 0, "delta", "tool_calls") && chunk.dig("choices", 0, "delta", "role")
message = handle_message(chunk.dig("choices", 0, "delta"))
prompt.messages << message
@response = ActiveAgent::GenerationProvider::Response.new(prompt:, message:)
@response = ActiveAgent::GenerationProvider::Response.new(
prompt:,
message:,
raw_response: chunk,
raw_request: @streaming_request_params
)
end

if chunk.dig("choices", 0, "finish_reason")
@@ -92,18 +97,23 @@ def format_image_content(message)
# The format_tools method comes from ToolManagement module
# The provider_messages method comes from MessageFormatting module

def chat_response(response)
def chat_response(response, request_params = nil)
return @response if prompt.options[:stream]
message_json = response.dig("choices", 0, "message")
message_json["id"] = response.dig("id") if message_json["id"].blank?
message = handle_message(message_json)

update_context(prompt: prompt, message: message, response: response)

@response = ActiveAgent::GenerationProvider::Response.new(prompt: prompt, message: message, raw_response: response)
@response = ActiveAgent::GenerationProvider::Response.new(
prompt: prompt,
message: message,
raw_response: response,
raw_request: request_params
)
end

def responses_response(response)
def responses_response(response, request_params = nil)
message_json = response["output"].find { |output_item| output_item["type"] == "message" }
message_json["id"] = response.dig("id") if message_json["id"].blank?

@@ -116,7 +126,12 @@ def responses_response(response)
content_type: prompt.output_schema.present? ? "application/json" : "text/plain",
)

@response = ActiveAgent::GenerationProvider::Response.new(prompt: prompt, message: message, raw_response: response)
@response = ActiveAgent::GenerationProvider::Response.new(
prompt: prompt,
message: message,
raw_response: response,
raw_request: request_params
)
end

def handle_message(message_json)
@@ -133,13 +148,16 @@ def handle_message(message_json)
# handle_actions is now provided by ToolManagement module

def chat_prompt(parameters: prompt_parameters)
parameters[:stream] = provider_stream if prompt.options[:stream] || config["stream"]
chat_response(@client.chat(parameters: parameters))
if prompt.options[:stream] || config["stream"]
parameters[:stream] = provider_stream
@streaming_request_params = parameters
end
chat_response(@client.chat(parameters: parameters), parameters)
end

def responses_prompt(parameters: responses_parameters)
# parameters[:stream] = provider_stream if prompt.options[:stream] || config["stream"]
responses_response(@client.responses.create(parameters: parameters))
responses_response(@client.responses.create(parameters: parameters), parameters)
end

def responses_parameters(model: @prompt.options[:model] || @model_name, messages: @prompt.messages, temperature: @prompt.options[:temperature] || @config["temperature"] || 0.7, tools: @prompt.actions, structured_output: @prompt.output_schema)
@@ -158,14 +176,20 @@ def embeddings_parameters(input: prompt.message.content, model: "text-embedding-
}
end

def embeddings_response(response)
def embeddings_response(response, request_params = nil)
message = ActiveAgent::ActionPrompt::Message.new(content: response.dig("data", 0, "embedding"), role: "assistant")

@response = ActiveAgent::GenerationProvider::Response.new(prompt: prompt, message: message, raw_response: response)
@response = ActiveAgent::GenerationProvider::Response.new(
prompt: prompt,
message: message,
raw_response: response,
raw_request: request_params
)
end

def embeddings_prompt(parameters:)
embeddings_response(@client.embeddings(parameters: embeddings_parameters))
params = embeddings_parameters
embeddings_response(@client.embeddings(parameters: params), params)
end
end
end
@@ -38,7 +38,7 @@ def initialize(config)
uri_base: "https://openrouter.ai/api/v1",
access_token: @access_token,
log_errors: Rails.env.development?,
default_headers: openrouter_headers
extra_headers: openrouter_headers
)
end

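The keyword swap above appears to be needed because the ruby-openai client accepts extra_headers rather than default_headers. A rough sketch of the resulting client construction, with placeholder attribution headers standing in for whatever openrouter_headers actually returns in this repo:

# Illustrative only; the real header values come from openrouter_headers.
OpenAI::Client.new(
  uri_base: "https://openrouter.ai/api/v1",
  access_token: access_token,
  log_errors: Rails.env.development?,
  extra_headers: {
    "HTTP-Referer" => "https://your-app.example",  # assumed OpenRouter attribution headers
    "X-Title" => "Your App"
  }
)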
46 changes: 38 additions & 8 deletions lib/active_agent/generation_provider/response.rb
@@ -3,28 +3,24 @@
module ActiveAgent
module GenerationProvider
class Response
attr_reader :message, :prompt, :raw_response
attr_reader :message, :prompt, :raw_response, :raw_request
attr_accessor :metadata

def initialize(prompt:, message: nil, raw_response: nil, metadata: nil)
def initialize(prompt:, message: nil, raw_response: nil, raw_request: nil, metadata: nil)
@prompt = prompt
@message = message || prompt.message
@raw_response = raw_response
@raw_request = sanitize_request(raw_request)
@metadata = metadata || {}
end

# Extract usage statistics from the raw response
def usage
return nil unless @raw_response

# OpenAI/OpenRouter format
# Most providers store usage in the same format
if @raw_response.is_a?(Hash) && @raw_response["usage"]
@raw_response["usage"]
# Anthropic format
elsif @raw_response.is_a?(Hash) && @raw_response["usage"]
@raw_response["usage"]
else
nil
end
end

@@ -40,6 +36,40 @@ def completion_tokens
def total_tokens
usage&.dig("total_tokens")
end

private

def sanitize_request(request)
return nil if request.nil?
return request unless request.is_a?(Hash)

# Deep clone the request to avoid modifying the original
sanitized = request.deep_dup

# Sanitize any string values in the request
sanitize_hash_values(sanitized)
end

def sanitize_hash_values(hash)
hash.each do |key, value|
case value
when String
# Use ActiveAgent's sanitize_credentials to replace sensitive data
hash[key] = ActiveAgent.sanitize_credentials(value)
when Hash
sanitize_hash_values(value)
when Array
value.each_with_index do |item, index|
if item.is_a?(String)
value[index] = ActiveAgent.sanitize_credentials(item)
elsif item.is_a?(Hash)
sanitize_hash_values(item)
end
end
end
end
hash
end
end
end
end
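A minimal sketch (not from the PR) of how the sanitization above behaves when a Response is built directly. The request hash and credential string are made up, and the exact masked form depends on what ActiveAgent.sanitize_credentials does with a given value:

# Assumes ActiveAgent is loaded; values are illustrative.
message = ActiveAgent::ActionPrompt::Message.new(content: "hi", role: "assistant")

response = ActiveAgent::GenerationProvider::Response.new(
  prompt: nil,  # a real ActionPrompt::Prompt in practice; unused when a message is given
  message: message,
  raw_response: { "usage" => { "prompt_tokens" => 3, "completion_tokens" => 5, "total_tokens" => 8 } },
  raw_request: {
    model: "gpt-4o-mini",
    messages: [ { role: "user", content: "hi" } ],
    api_key: "sk-live-abc123"  # a String value like this is run through sanitize_credentials
  }
)

response.raw_request   # deep-duped copy with sanitized String values; the original hash is untouched
response.total_tokens  # => 8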
2 changes: 1 addition & 1 deletion test/dummy/config/credentials.yml.enc
@@ -1 +1 @@
7RXYXaZ+9Xqohw1oMp++FYX/bKiNI9tBusC5oBNAPdj12WYZmBn+d4GZk1whE0gqwJjscVDk0dJgIXt3sZlqDHIe7pd8n4EQsnZ1mPXk2R5QVyEPuiIfISBhx1skCgqXI0ga9HBGalGoUxQLtEO2rbCYea7YCfOnwFqLsZ9ZD8ciiL7hLY3jHafhRo7CuRYcBpzOlaZLlB574nLphtxgsL0xxAi8t7bdueLpDegAxSHmpZNzWmMkNcC7W9UmCVlyieP1jCAhFkuS5JMG3WpPYo1Ft6PsvYf8rcFJEhr8s/L75B6MFZFW/45YdHRWBron9CJTNxMIdfzY3E30Bb2zoL3juv0BRbiXC3PkkB1HT+cyxRPR8XYASVHLOLH0enFE+839OFI5edOFDtMQsPoO8BSbcnji45Xc3ISORDInOBEWtJ88vSDU2S9ufY79szgWydVBYZuOW5g8ayFGG4gWHusnivRLPfjigkL8/44tb7Fmuh2vdewNma+4c1Yvj7xFMO8K4cba0xsysdPYUaCxh1Ys6tk9NZ4Y7A1+QwRFxSSHXJtuP5JsTs5vvZxc5BEPXOxv81dwmAObmSPZBGodP2XqfRb8ludCFUW/v7g24ZZx1SVTzfk7LW6bhF6+oQMkUckeXeoJrTEZBlDUuOXbrtmd+CudZukXAz8JIIhJx2ZRFNkI/yl88Usmo3vIV1topRzm0kKTdXXg0q2/VO+uOY0rF6ZW/eCYyttrcl2KSEFLTtUneKsVJczizY4L+wqe5b+tpOUHzv/EuVbsbiXZU77S/bcbWGLbb6juJgtMSUkBzAYeY5jPwdLQ2l5Vn5yCd8e4FmHn/l3cIXjrtyvjxulUnm/PkZRb0YggC8795CL/DDpftCAZqmCKV34ypn+LMhqHh1fKYuWbnuBvX0PaIXb42u2fbCrFUvAa1KRCK6OKp7j6j7H1i9eZofpQZEhOGBYOOZaZV17LVcCOO4WuskNR6UxM/BnmJ3oTMRw+wEi+KEFVKvYWDTQeepnk6p6GjVCjbr3w2VDXtpjWb3vLWh/QMwHGpPzNltmroZZDAsEsxXXZQ+3nsp4tQyeLO8w8xE1vu1lKgcFxHaQ4IX/bXSaUTOJ0OovUEL7JjY63Sr9+f9KLTWbbgQ/QZrBwoGfwB364ylOt1Vfojqz6ATasdcgQXe38L548mmEDBrJ1idfj0mqNznzYpzB7/waLK7SG92pEF6b67uRxXG73PZfCVrRGXHKwu2g27J1+bUT26Ojpa0qzHEIEEUabufJSxNKKEAXqgv/DCkXyq18YrZx+NA==--85+CXyjvfbFJi6BM--CHQakeLJ4Z0qHE/8QWnxVQ==
j1oFzRL/Z5P7MYpUZlH1tiicXPeonq+0LknaUFGHhsrIGUkk7WXGSQN89AeAehInTbXvEA7C7QZsn0xbj7xdYRAb7BVoJk1MsREM7TkfWmXnvQN66jomSsu/rO9StAoVg1PvfUANDl1ubwrtfuNRydLKCR4+vb1ZZKYneFXXBc5mDHvEPKtEnMypMeUC4LanqY2iqPTtPFLBF/3M20ofctm0s5yRcfVMQO+9KwKvZ/nkNne1LLNSbOmM3243vw+1CqO4Kw54qDF1idMamdU3tqoS64CCtcVmUid7LuYdeK2rYOk2wASNgTaascIBdsl6ZdTXq2ODj4b7WJrrRaA1JYPccwQYB7dFaKT0pkI6vvBjA1LZuFxl+VTWSpjmZCKkw0rv99ckfbV1hdkRAOHB5o3g4ceksOuFDpgJaACptioie9IB96FTwxEIaigYt2GkVFDlqt8F9N+z98cg2FW3xxWR9373hBY7i+huU2i9f8fSxvmMrmqsBJ2B64UmDYSLfYEaOevfZ1EuEznJFHAomum2zmxT5yOzNrZFlstoC793KdpFSnQugBzpa1ULWuKrl5eybdiD97L/5U0nsFEltoyr+8gDxaZpIxgrOxTyGOm8khfaPQJlia0hNvuzM5VvglpiIOY+BsFhe5xBenP0Q068g65fbXiUqVk1mXz8yW8GyGKQs3MAQj9h6td9z1f3Bn77BpRFh4pElkn07jmF8wyLQLKSrJxq37gcDmu03wwFLQ32dBOAQbbikIoNcZE0eVneZARzohdzffZzk/60g3fUNrzi0d/UXQZt0PKuBrGeQfWxkllQUEH7UqTYCzRM7PD57weSAmD/8m3N6sFOFIwsdQ/17PwTi04gdPW9YDZbVXrtw4dmw+pRPqs8bORHu+seKSIuYpBZoZuq9go0QMqKW4lVNU8eI5V3u8MJ5K9i92cQZg5Y5zpWiNpgQQhTFAq66Wv1yzqV4hGQV+OGGdQt1KHSWEsmScdUPk33jRfy8POw/iMt2NvFLcCgdsoVnm1zDz9CnKnVzJQ6opGPoIg/ncIu2ErcICGrikeV8JFgBxXpb1K2tOxkyaV8t6DELJrkYcXQ1hHy2c2AfzwvIyYB+wRZuVL7m0EngbgGDVdykgzD+XR5aGLlwc2fXZb47GkF6rrsrNXJlc7FCIkpPAkVmIy3RoSGNri0DwwKHTY3Kf57N2oTP95Ae7nPPVs99WMSy4eL2Wooy9gxM0RfxaCYirakM2OmLYFZ6LMkmXWCaeZcoiPbdQojjDhrY3PRyQx2OuwevxQzazaAFwadqZoWMSkg2LaSz/8eMa5IwM5lfoorF9ZTO3iTXyDAcHCXc+IzKg9DJ6oMWnF22VSTUQu38XswCeUquy9jkN8hSFPwE34Qhx3CesKzmIuEm0m8w+oZWigDnU2/iGf/xwv3zsR9hafoedDuv8oh6FA017VWTel/v7HZHYgcLZMB7V1Kj3ZBdn0O7n6eVd+fpFGAhClK/N1vIsgCcmVwJYuqKcWYUapQ0CQN4U031oF4KknwHcWHTkVrtzLlfh1RKXaZ8prVr/QX/fKlCrm6V5Lvfml1PA==--AgUCVbFPGIA1BP05--YI0wFyd8L6DcXLi7cDU6Cg==
155 changes: 155 additions & 0 deletions test/generation_provider/anthropic_raw_request_test.rb
@@ -0,0 +1,155 @@
require "test_helper"
require "active_agent/generation_provider/anthropic_provider"
require "active_agent/action_prompt/prompt"

module ActiveAgent
module GenerationProvider
class AnthropicRawRequestTest < ActiveSupport::TestCase
setup do
@config = {
"api_key" => "test-key",
"model" => "claude-3-opus-20240229"
}
@provider = AnthropicProvider.new(@config)

@prompt = ActiveAgent::ActionPrompt::Prompt.new(
messages: [
ActiveAgent::ActionPrompt::Message.new(
content: "Hello, Claude!",
role: "user"
)
],
actions: [],
options: {},
output_schema: nil
)
end

test "chat_response includes raw_request when provided" do
mock_response = {
"id" => "msg-123",
"content" => [
{
"type" => "text",
"text" => "Hello! I'm Claude. How can I assist you today?"
}
],
"stop_reason" => "end_turn",
"usage" => {
"input_tokens" => 10,
"output_tokens" => 12
}
}

request_params = {
model: "claude-3-opus-20240229",
messages: [ { role: "user", content: "Hello, Claude!" } ],
max_tokens: 1024,
temperature: 0.7
}

@provider.instance_variable_set(:@prompt, @prompt)
response = @provider.send(:chat_response, mock_response, request_params)

assert_not_nil response
assert_equal request_params, response.raw_request
assert_equal mock_response, response.raw_response
end

test "chat_response with tool use includes raw_request" do
mock_response = {
"id" => "msg-456",
"content" => [
{
"type" => "text",
"text" => "I'll help you with that calculation."
},
{
"type" => "tool_use",
"id" => "tool-789",
"name" => "calculator",
"input" => { "expression" => "2 + 2" }
}
],
"stop_reason" => "tool_use"
}

request_params = {
model: "claude-3-opus-20240229",
messages: [ { role: "user", content: "What is 2 + 2?" } ],
tools: [
{
name: "calculator",
description: "Performs calculations",
input_schema: {
type: "object",
properties: {
expression: { type: "string" }
}
}
}
],
max_tokens: 1024
}

@provider.instance_variable_set(:@prompt, @prompt)
response = @provider.send(:chat_response, mock_response, request_params)

assert_not_nil response
assert_equal request_params, response.raw_request
assert_equal mock_response, response.raw_response
assert response.message.action_requested
end

test "streaming request params are captured" do
request_params = {
model: "claude-3-opus-20240229",
messages: [ { role: "user", content: "Stream test" } ],
stream: true,
max_tokens: 1024
}

@provider.instance_variable_set(:@prompt, @prompt)

# Simulate setting streaming params like in chat_prompt
@provider.instance_variable_set(:@streaming_request_params, request_params)

assert_equal request_params, @provider.instance_variable_get(:@streaming_request_params)
end

test "response includes metadata alongside raw_request and raw_response" do
mock_response = {
"id" => "msg-meta-123",
"content" => [
{
"type" => "text",
"text" => "Response with metadata"
}
],
"stop_reason" => "end_turn",
"model" => "claude-3-opus-20240229",
"usage" => {
"input_tokens" => 5,
"output_tokens" => 4
}
}

request_params = {
model: "claude-3-opus-20240229",
messages: [ { role: "user", content: "Test" } ],
max_tokens: 100
}

@provider.instance_variable_set(:@prompt, @prompt)
response = @provider.send(:chat_response, mock_response, request_params)

assert_not_nil response
assert_equal request_params, response.raw_request
assert_equal mock_response, response.raw_response

# Response should also have metadata
assert_instance_of Hash, response.metadata
end
end
end
end