diff --git a/docs/internals/llm-gateway.md b/docs/internals/llm-gateway.md index 22c7fa8..44be897 100644 --- a/docs/internals/llm-gateway.md +++ b/docs/internals/llm-gateway.md @@ -63,10 +63,10 @@ The sequence is: 1. `F::call_native()` chooses the endpoint path and request body 2. `Gateway::call_chat_native()` executes the HTTP POST against the provider instance base URL -3. for complete calls, `F::parse_native_response()` parses the JSON response into `F::Response` +3. for complete calls, `F::parse_native_response()` parses the JSON response into `F::Response`, then `F::response_usage()` can extract a `Usage` snapshot from that typed response 4. for stream calls, `NativeStream` converts provider-native chunks into `F::StreamChunk` and sends final `Usage` through a oneshot channel -The gateway currently returns `Usage::default()` for native complete calls because there is not yet a generic format hook for extracting usage out of arbitrary native response types. +Native complete calls no longer hard-code `Usage::default()`. Formats can now report native complete-call usage through `ChatFormat::response_usage()`, while formats that keep the default hook still return an empty `Usage` value. ## `ChatResponse` @@ -97,9 +97,8 @@ This module does not attempt to finish the full Layer 3 design. 
- `SessionStore` is not wired yet - `chat_completion()` and `messages()` are implemented as convenience helpers today - `responses()` remains deferred until its corresponding format lands -- `AnthropicMessagesFormat` still rejects non-native hub streaming; only its native provider path can stream today - only `StreamReaderKind::Sse` is wired today; `AwsEventStream` and `JsonArrayStream` are still deferred -- native complete-call usage extraction is still format-specific future work +- native complete-call usage reporting depends on each format implementing `ChatFormat::response_usage()`; formats that do not override it still return empty usage ## Why This Slice Exists diff --git a/src/gateway/formats/anthropic_messages.rs b/src/gateway/formats/anthropic_messages.rs index ea516d5..1bda4df 100644 --- a/src/gateway/formats/anthropic_messages.rs +++ b/src/gateway/formats/anthropic_messages.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use serde_json::{Value, json}; use crate::gateway::{ @@ -7,24 +9,48 @@ use crate::gateway::{ anthropic::{ AnthropicContent, AnthropicContentBlock, AnthropicMessage, AnthropicMessagesRequest, AnthropicMessagesResponse, AnthropicStreamEvent, AnthropicTool, AnthropicToolChoice, - AnthropicUsage, CacheControl, ImageSource, SystemPrompt, + AnthropicUsage, CacheControl, ContentDelta, DeltaUsage, ImageSource, MessageDelta, + MessageStartPayload, MessageStartUsage, SystemPrompt, }, - common::{AnthropicMessagesExtras, BridgeContext}, + common::{AnthropicMessagesExtras, BridgeContext, Usage}, openai::{ - ChatCompletionRequest, ChatCompletionResponse, ChatCompletionUsage, ChatMessage, - ContentPart, FunctionCall, FunctionDefinition, ImageUrl, MessageContent, StopCondition, - Tool, ToolCall, ToolChoice, ToolChoiceFunction, + ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionUsage, ChatMessage, ContentPart, FunctionCall, FunctionDefinition, + ImageUrl, MessageContent, StopCondition, Tool, ToolCall, ToolChoice, 
+ ToolChoiceFunction, }, }, }; pub struct AnthropicMessagesFormat; +/// Streaming bridge state for Anthropic message assembly. +/// It tracks message/block lifecycle, token counters, stop reason, and tool-to-block mappings while hub chunks are converted incrementally. +#[derive(Debug, Clone, Default)] +pub struct AnthropicBridgeState { + message_started: bool, + current_block_index: usize, + current_block_type: Option, + current_block_open: bool, + stop_reason: Option, + input_tokens: Option, + output_tokens: Option, + cache_creation_input_tokens: Option, + cache_read_input_tokens: Option, + tool_block_map: HashMap, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AnthropicBlockType { + Text, + ToolUse, +} + impl ChatFormat for AnthropicMessagesFormat { type Request = AnthropicMessagesRequest; type Response = AnthropicMessagesResponse; type StreamChunk = AnthropicStreamEvent; - type BridgeState = (); + type BridgeState = AnthropicBridgeState; type NativeStreamState = AnthropicMessagesNativeStreamState; fn name() -> &'static str { @@ -40,16 +66,18 @@ impl ChatFormat for AnthropicMessagesFormat { } fn to_hub(req: &Self::Request) -> Result<(ChatCompletionRequest, BridgeContext)> { + if req.cache_control.is_some() { + return Err(GatewayError::Bridge( + "Anthropic top-level cache_control is not supported by hub bridging".into(), + )); + } + ensure_no_message_cache_controls_for_hub(&req.messages)?; + let (mut messages, system_cache_control) = system_prompt_to_hub_messages(req.system.as_ref())?; for message in &req.messages { messages.extend(anthropic_message_to_hub_messages(message)?); } - if req.stream.unwrap_or(false) { - return Err(GatewayError::Bridge( - "Anthropic messages hub streaming bridge is not implemented yet".into(), - )); - } let metadata = req .metadata @@ -97,13 +125,43 @@ impl ChatFormat for AnthropicMessagesFormat { } fn from_hub_stream( - _chunk: &crate::gateway::types::openai::ChatCompletionChunk, - _state: &mut Self::BridgeState, + chunk: 
&ChatCompletionChunk, + state: &mut Self::BridgeState, _ctx: &BridgeContext, ) -> Result> { - Err(GatewayError::Bridge( - "Anthropic messages hub streaming bridge is not implemented yet".into(), - )) + anthropic_bridge_state_machine(chunk, state) + } + + fn stream_end_events( + state: &mut Self::BridgeState, + _ctx: &BridgeContext, + ) -> Vec { + if !state.message_started { + return vec![]; + } + + let mut events = Vec::new(); + if state.current_block_open { + events.push(AnthropicStreamEvent::ContentBlockStop { + index: state.current_block_index, + }); + state.current_block_open = false; + state.current_block_type = None; + } + events.push(AnthropicStreamEvent::MessageDelta { + delta: MessageDelta { + stop_reason: state.stop_reason.clone(), + stop_sequence: None, + }, + usage: DeltaUsage { + output_tokens: state.output_tokens, + input_tokens: state.input_tokens, + cache_creation_input_tokens: state.cache_creation_input_tokens, + cache_read_input_tokens: state.cache_read_input_tokens, + }, + }); + events.push(AnthropicStreamEvent::MessageStop); + events } fn native_support(provider: &dyn ProviderCapabilities) -> Option> @@ -147,7 +205,20 @@ impl ChatFormat for AnthropicMessagesFormat { }); }; - handler.transform_anthropic_messages_stream_chunk(raw, state) + let events = handler.transform_anthropic_messages_stream_chunk(raw, state)?; + for event in &events { + update_native_usage_from_event(event, state); + } + + Ok(events) + } + + fn native_usage(state: &Self::NativeStreamState) -> Usage { + state.usage.clone() + } + + fn response_usage(response: &Self::Response) -> Usage { + anthropic_usage_to_common_usage(&response.usage) } fn parse_native_response(native: &NativeHandler<'_>, body: Value) -> Result @@ -319,13 +390,15 @@ fn anthropic_assistant_blocks_to_hub( text_segments.push(text.clone()); rich_parts.push(ContentPart::Text { text: text.clone() }); } - AnthropicContentBlock::Image { source } => { + AnthropicContentBlock::Image { source, .. 
} => { has_non_text_part = true; rich_parts.push(ContentPart::ImageUrl { image_url: anthropic_source_to_openai_image_url(source)?, }); } - AnthropicContentBlock::ToolUse { id, name, input } => { + AnthropicContentBlock::ToolUse { + id, name, input, .. + } => { tool_calls.push(ToolCall { id: id.clone(), r#type: "function".into(), @@ -374,7 +447,7 @@ fn anthropic_blocks_to_openai_content( text_segments.push(text.clone()); rich_parts.push(ContentPart::Text { text: text.clone() }); } - AnthropicContentBlock::Image { source } => { + AnthropicContentBlock::Image { source, .. } => { has_non_text_part = true; rich_parts.push(ContentPart::ImageUrl { image_url: anthropic_source_to_openai_image_url(source)?, @@ -510,6 +583,7 @@ fn openai_message_to_anthropic_blocks(message: &ChatMessage) -> Result Ok(AnthropicContentBlock::Image { source: openai_image_url_to_anthropic_source(&image_url.url)?, + cache_control: None, }), }) .collect(), @@ -557,6 +632,115 @@ fn openai_usage_to_anthropic(usage: Option<&ChatCompletionUsage>) -> AnthropicUs output_tokens: usage.completion_tokens, cache_creation_input_tokens: 0, cache_read_input_tokens: cached_tokens, + cache_creation: None, + } +} + +fn anthropic_usage_to_common_usage(usage: &AnthropicUsage) -> Usage { + let input_tokens = + usage.input_tokens + usage.cache_creation_input_tokens + usage.cache_read_input_tokens; + + Usage { + input_tokens: Some(input_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(input_tokens + usage.output_tokens), + cache_creation_input_tokens: Some(usage.cache_creation_input_tokens), + cache_read_input_tokens: Some(usage.cache_read_input_tokens), + ..Default::default() + } +} + +fn anthropic_delta_usage_to_common_usage(usage: &DeltaUsage, previous: &Usage) -> Usage { + let cache_creation_input_tokens = usage + .cache_creation_input_tokens + .or(previous.cache_creation_input_tokens); + let cache_read_input_tokens = usage + .cache_read_input_tokens + .or(previous.cache_read_input_tokens); 
+ let input_tokens = usage.input_tokens.map(|input_tokens| { + input_tokens + + cache_creation_input_tokens.unwrap_or(0) + + cache_read_input_tokens.unwrap_or(0) + }); + + Usage { + input_tokens, + output_tokens: usage.output_tokens, + total_tokens: input_tokens + .zip(usage.output_tokens) + .map(|(input_tokens, output_tokens)| input_tokens + output_tokens), + cache_creation_input_tokens: usage.cache_creation_input_tokens, + cache_read_input_tokens: usage.cache_read_input_tokens, + ..Default::default() + } +} + +fn update_native_usage_from_event( + event: &AnthropicStreamEvent, + state: &mut AnthropicMessagesNativeStreamState, +) { + match event { + AnthropicStreamEvent::MessageStart { message } => { + state + .usage + .merge(&anthropic_message_start_usage_to_common_usage( + &message.usage, + )); + } + AnthropicStreamEvent::MessageDelta { usage, .. } => { + state + .usage + .merge(&anthropic_delta_usage_to_common_usage(usage, &state.usage)); + } + _ => {} + } +} + +fn ensure_no_message_cache_controls_for_hub(messages: &[AnthropicMessage]) -> Result<()> { + for message in messages { + let AnthropicContent::Blocks(blocks) = &message.content else { + continue; + }; + + if blocks + .iter() + .any(|block| anthropic_block_cache_control(block).is_some()) + { + return Err(GatewayError::Bridge( + "Anthropic per-block cache_control on user/assistant messages is not supported by hub bridging" + .into(), + )); + } + } + + Ok(()) +} + +fn anthropic_block_cache_control(block: &AnthropicContentBlock) -> Option<&CacheControl> { + match block { + AnthropicContentBlock::Text { cache_control, .. } + | AnthropicContentBlock::Image { cache_control, .. } + | AnthropicContentBlock::ToolUse { cache_control, .. } + | AnthropicContentBlock::ToolResult { cache_control, .. 
} => cache_control.as_ref(), + } +} + +fn anthropic_message_start_usage_to_common_usage(usage: &MessageStartUsage) -> Usage { + let input_tokens = usage.input_tokens.map(|input_tokens| { + input_tokens + + usage.cache_creation_input_tokens.unwrap_or(0) + + usage.cache_read_input_tokens.unwrap_or(0) + }); + + Usage { + input_tokens, + output_tokens: usage.output_tokens, + total_tokens: input_tokens + .zip(usage.output_tokens) + .map(|(input_tokens, output_tokens)| input_tokens + output_tokens), + cache_creation_input_tokens: usage.cache_creation_input_tokens, + cache_read_input_tokens: usage.cache_read_input_tokens, + ..Default::default() } } @@ -617,6 +801,207 @@ fn hub_message(role: &str, content: Option) -> ChatMessage { } } +fn anthropic_bridge_state_machine( + chunk: &ChatCompletionChunk, + state: &mut AnthropicBridgeState, +) -> Result> { + if chunk.choices.len() > 1 { + return Err(GatewayError::Bridge( + "Anthropic stream bridge cannot represent multiple OpenAI choices".into(), + )); + } + + if let Some(usage) = &chunk.usage { + let cached_tokens = usage + .prompt_tokens_details + .as_ref() + .and_then(|details| details.cached_tokens) + .unwrap_or(0); + state.input_tokens = Some(usage.prompt_tokens.saturating_sub(cached_tokens)); + state.output_tokens = Some(usage.completion_tokens); + state.cache_creation_input_tokens = Some(0); + state.cache_read_input_tokens = Some(cached_tokens); + } + + let Some(choice) = chunk.choices.first() else { + return Ok(vec![]); + }; + + if choice.index != 0 { + return Err(GatewayError::Bridge(format!( + "Anthropic stream bridge only supports OpenAI choice index 0, got {}", + choice.index + ))); + } + + if let Some(role) = choice.delta.role.as_deref() + && role != "assistant" + { + return Err(GatewayError::Bridge(format!( + "Anthropic stream bridge requires assistant deltas, got {}", + role + ))); + } + + let mut events = Vec::new(); + if !state.message_started { + state.message_started = true; + 
events.push(AnthropicStreamEvent::MessageStart { + message: MessageStartPayload { + id: chunk.id.clone(), + r#type: "message".into(), + role: "assistant".into(), + model: chunk.model.clone(), + usage: MessageStartUsage { + input_tokens: state.input_tokens, + output_tokens: state.output_tokens, + cache_creation_input_tokens: state.cache_creation_input_tokens, + cache_read_input_tokens: state.cache_read_input_tokens, + cache_creation: None, + }, + }, + }); + } + + if let Some(content) = choice.delta.content.as_ref() + && !content.is_empty() + { + ensure_text_block_open(state, &mut events); + events.push(AnthropicStreamEvent::ContentBlockDelta { + index: state.current_block_index, + delta: ContentDelta::TextDelta { + text: content.clone(), + }, + }); + } + + if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let block_index = if let Some(&block_index) = state.tool_block_map.get(&tool_call.index) + { + if !state.current_block_open + || state.current_block_type != Some(AnthropicBlockType::ToolUse) + || state.current_block_index != block_index + { + return Err(GatewayError::Bridge( + "Anthropic stream bridge does not support interleaved OpenAI tool call deltas" + .into(), + )); + } + block_index + } else { + let tool_type = tool_call.r#type.as_deref().ok_or_else(|| { + GatewayError::Bridge( + "Anthropic stream bridge requires tool call types on the first delta" + .into(), + ) + })?; + if tool_type != "function" { + return Err(GatewayError::Bridge(format!( + "Anthropic stream bridge only supports function tool calls, got {}", + tool_type + ))); + } + let tool_id = tool_call.id.as_deref().ok_or_else(|| { + GatewayError::Bridge( + "Anthropic stream bridge requires tool call ids on the first delta".into(), + ) + })?; + let function = tool_call.function.as_ref().ok_or_else(|| { + GatewayError::Bridge( + "Anthropic stream bridge requires function metadata on the first tool delta" + .into(), + ) + })?; + let tool_name = 
function.name.as_deref().ok_or_else(|| { + GatewayError::Bridge( + "Anthropic stream bridge requires function names on the first tool delta" + .into(), + ) + })?; + + close_current_block(state, &mut events); + let block_index = state.current_block_index; + state.tool_block_map.insert(tool_call.index, block_index); + state.current_block_type = Some(AnthropicBlockType::ToolUse); + state.current_block_open = true; + events.push(AnthropicStreamEvent::ContentBlockStart { + index: block_index, + content_block: AnthropicContentBlock::ToolUse { + id: tool_id.to_string(), + name: tool_name.to_string(), + input: json!({}), + cache_control: None, + }, + }); + block_index + }; + + if let Some(arguments) = tool_call + .function + .as_ref() + .and_then(|function| function.arguments.as_ref()) + && !arguments.is_empty() + { + events.push(AnthropicStreamEvent::ContentBlockDelta { + index: block_index, + delta: ContentDelta::InputJsonDelta { + partial_json: arguments.clone(), + }, + }); + } + } + } + + if let Some(finish_reason) = choice.finish_reason.as_deref() { + state.stop_reason = Some(openai_finish_reason_to_anthropic_stream(finish_reason)); + } + + Ok(events) +} + +fn ensure_text_block_open( + state: &mut AnthropicBridgeState, + events: &mut Vec, +) { + if state.current_block_open && state.current_block_type == Some(AnthropicBlockType::Text) { + return; + } + + close_current_block(state, events); + events.push(AnthropicStreamEvent::ContentBlockStart { + index: state.current_block_index, + content_block: AnthropicContentBlock::Text { + text: String::new(), + cache_control: None, + }, + }); + state.current_block_type = Some(AnthropicBlockType::Text); + state.current_block_open = true; +} + +fn close_current_block(state: &mut AnthropicBridgeState, events: &mut Vec) { + if !state.current_block_open { + return; + } + + events.push(AnthropicStreamEvent::ContentBlockStop { + index: state.current_block_index, + }); + state.current_block_index += 1; + state.current_block_type = None; 
+ state.current_block_open = false; +} + +fn openai_finish_reason_to_anthropic_stream(finish_reason: &str) -> String { + match finish_reason { + "stop" => "end_turn".into(), + "length" => "max_tokens".into(), + "tool_calls" => "tool_use".into(), + other => other.to_string(), + } +} + #[cfg(test)] mod tests { use serde_json::json; @@ -694,20 +1079,94 @@ mod tests { } #[test] - fn request_to_hub_rejects_streaming_until_hub_stream_bridge_lands() { - let request: AnthropicMessagesRequest = serde_json::from_value(json!({ - "model": "claude-3-5-sonnet-20241022", - "max_tokens": 256, - "messages": [{"role": "user", "content": "hello"}], - "stream": true - })) + fn from_hub_stream_emits_text_lifecycle_and_end_events() { + let mut state = super::AnthropicBridgeState::default(); + let first_chunk: crate::gateway::types::openai::ChatCompletionChunk = + serde_json::from_value(json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1, + "model": "gpt-test", + "choices": [{ + "index": 0, + "delta": { + "role": "assistant", + "content": "hello" + }, + "finish_reason": null + }] + })) + .unwrap(); + let usage_chunk: crate::gateway::types::openai::ChatCompletionChunk = + serde_json::from_value(json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1, + "model": "gpt-test", + "choices": [], + "usage": { + "prompt_tokens": 7, + "completion_tokens": 9, + "total_tokens": 16, + "prompt_tokens_details": {"cached_tokens": 2} + } + })) + .unwrap(); + + let events = AnthropicMessagesFormat::from_hub_stream( + &first_chunk, + &mut state, + &BridgeContext::default(), + ) .unwrap(); + assert!(matches!( + &events[0], + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } + if message.id == "chatcmpl-123" + && message.usage.input_tokens.is_none() + && message.usage.output_tokens.is_none() + )); + assert!(matches!( + &events[1], + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. 
} + if *index == 0 + )); + assert!(matches!( + &events[2], + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } + if *index == 0 + && matches!(delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") + )); - let result = AnthropicMessagesFormat::to_hub(&request); + assert!( + AnthropicMessagesFormat::from_hub_stream( + &usage_chunk, + &mut state, + &BridgeContext::default(), + ) + .unwrap() + .is_empty() + ); + + let end_events = + AnthropicMessagesFormat::stream_end_events(&mut state, &BridgeContext::default()); assert!(matches!( - result, - Err(GatewayError::Bridge(message)) - if message.contains("hub streaming bridge") + &end_events[0], + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } + if *index == 0 + )); + assert!(matches!( + &end_events[1], + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { delta, usage } + if delta.stop_reason.is_none() + && usage.input_tokens == Some(5) + && usage.output_tokens == Some(9) + && usage.cache_creation_input_tokens == Some(0) + && usage.cache_read_input_tokens == Some(2) + )); + assert!(matches!( + &end_events[2], + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop )); } @@ -753,6 +1212,7 @@ mod tests { assert_eq!(bridged.usage.input_tokens, 10); assert_eq!(bridged.usage.output_tokens, 7); assert_eq!(bridged.usage.cache_read_input_tokens, 2); + assert!(bridged.usage.cache_creation.is_none()); assert!(matches!( &bridged.content[0], crate::gateway::types::anthropic::AnthropicContentBlock::Text { text, .. 
} @@ -766,24 +1226,203 @@ mod tests { } #[test] - fn hub_stream_bridge_is_not_implemented_yet() { - let chunk: crate::gateway::types::openai::ChatCompletionChunk = + fn from_hub_stream_maps_tool_use_deltas() { + let mut state = super::AnthropicBridgeState::default(); + let first_chunk: crate::gateway::types::openai::ChatCompletionChunk = serde_json::from_value(json!({ "id": "chatcmpl-123", "object": "chat.completion.chunk", "created": 1, "model": "gpt-test", - "choices": [] + "choices": [{ + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": [{ + "index": 0, + "id": "call_1", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"city\"" + } + }] + }, + "finish_reason": null + }] })) .unwrap(); + let second_chunk: crate::gateway::types::openai::ChatCompletionChunk = + serde_json::from_value(json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1, + "model": "gpt-test", + "choices": [{ + "index": 0, + "delta": { + "tool_calls": [{ + "index": 0, + "function": { + "arguments": ":\"SF\"}" + } + }] + }, + "finish_reason": "tool_calls" + }] + })) + .unwrap(); + + let first_events = AnthropicMessagesFormat::from_hub_stream( + &first_chunk, + &mut state, + &BridgeContext::default(), + ) + .unwrap(); + assert!(matches!( + &first_events[1], + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, content_block } + if *index == 0 + && matches!(content_block, crate::gateway::types::anthropic::AnthropicContentBlock::ToolUse { name, .. 
} if name == "get_weather") + )); + assert!(matches!( + &first_events[2], + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } + if *index == 0 + && matches!(delta, crate::gateway::types::anthropic::ContentDelta::InputJsonDelta { partial_json } if partial_json == "{\"city\"") + )); + + let second_events = AnthropicMessagesFormat::from_hub_stream( + &second_chunk, + &mut state, + &BridgeContext::default(), + ) + .unwrap(); + assert!(matches!( + &second_events[0], + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } + if *index == 0 + && matches!(delta, crate::gateway::types::anthropic::ContentDelta::InputJsonDelta { partial_json } if partial_json == ":\"SF\"}") + )); - let result = - AnthropicMessagesFormat::from_hub_stream(&chunk, &mut (), &BridgeContext::default()); + let end_events = + AnthropicMessagesFormat::stream_end_events(&mut state, &BridgeContext::default()); + assert!(matches!( + &end_events[1], + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { delta, .. 
} + if delta.stop_reason.as_deref() == Some("tool_use") + )); + } + #[test] + fn from_hub_stream_rejects_missing_or_non_function_tool_types() { + let mut missing_type_state = super::AnthropicBridgeState::default(); + let missing_type_chunk: crate::gateway::types::openai::ChatCompletionChunk = + serde_json::from_value(json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1, + "model": "gpt-test", + "choices": [{ + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": [{ + "index": 0, + "id": "call_1", + "function": { + "name": "get_weather", + "arguments": "{}" + } + }] + }, + "finish_reason": null + }] + })) + .unwrap(); + let missing_type_result = AnthropicMessagesFormat::from_hub_stream( + &missing_type_chunk, + &mut missing_type_state, + &BridgeContext::default(), + ); assert!(matches!( - result, + missing_type_result, Err(GatewayError::Bridge(message)) - if message.contains("hub streaming bridge") + if message.contains("requires tool call types") )); + + let mut invalid_type_state = super::AnthropicBridgeState::default(); + let invalid_type_chunk: crate::gateway::types::openai::ChatCompletionChunk = + serde_json::from_value(json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1, + "model": "gpt-test", + "choices": [{ + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": [{ + "index": 0, + "id": "call_1", + "type": "web_search", + "function": { + "name": "get_weather", + "arguments": "{}" + } + }] + }, + "finish_reason": null + }] + })) + .unwrap(); + let invalid_type_result = AnthropicMessagesFormat::from_hub_stream( + &invalid_type_chunk, + &mut invalid_type_state, + &BridgeContext::default(), + ); + assert!(matches!( + invalid_type_result, + Err(GatewayError::Bridge(message)) + if message.contains("only supports function tool calls") + )); + } + + #[test] + fn to_hub_rejects_top_level_cache_control() { + let request: AnthropicMessagesRequest = serde_json::from_value(json!({ + 
"model": "claude-3-5-sonnet-20241022", + "max_tokens": 256, + "cache_control": {"type": "ephemeral"}, + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + + let error = AnthropicMessagesFormat::to_hub(&request).unwrap_err(); + assert!( + matches!(error, GatewayError::Bridge(message) if message.contains("top-level cache_control")) + ); + } + + #[test] + fn to_hub_rejects_non_system_message_cache_control() { + let request: AnthropicMessagesRequest = serde_json::from_value(json!({ + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 256, + "messages": [{ + "role": "user", + "content": [{ + "type": "text", + "text": "hello", + "cache_control": {"type": "ephemeral"} + }] + }] + })) + .unwrap(); + + let error = AnthropicMessagesFormat::to_hub(&request).unwrap_err(); + assert!( + matches!(error, GatewayError::Bridge(message) if message.contains("per-block cache_control")) + ); } } diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index fbdc764..d7b05b1 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -154,11 +154,9 @@ impl Gateway { let body: Value = response.json().await.map_err(GatewayError::Http)?; let response = F::parse_native_response(native, body)?; + let usage = F::response_usage(&response); - Ok(ChatResponse::Complete { - response, - usage: Usage::default(), - }) + Ok(ChatResponse::Complete { response, usage }) } async fn call_chat_hub( @@ -905,22 +903,49 @@ mod tests { } #[tokio::test] - async fn messages_reject_hub_streaming_before_dispatch() { - let request_count = Arc::new(AtomicUsize::new(0)); - let request_count_clone = Arc::clone(&request_count); + async fn messages_stream_hub_chunks_into_anthropic_events_and_usage() { + let sse_body = format!( + "data: {}\n\ndata: {}\n\ndata: [DONE]\n\n", + serde_json::to_string(&json!({ + "id": "chatcmpl-789", + "object": "chat.completion.chunk", + "created": 1, + "model": "gpt-test", + "choices": [{ + "index": 0, + "delta": { + "role": "assistant", + "content": 
"hello" + }, + "finish_reason": null + }] + })) + .unwrap(), + serde_json::to_string(&json!({ + "id": "chatcmpl-789", + "object": "chat.completion.chunk", + "created": 1, + "model": "gpt-test", + "choices": [], + "usage": { + "prompt_tokens": 7, + "completion_tokens": 9, + "total_tokens": 16, + "prompt_tokens_details": {"cached_tokens": 2} + } + })) + .unwrap(), + ); let router = Router::new().route( "/v1/chat/completions", post(move || { - let request_count = Arc::clone(&request_count_clone); + let sse_body = sse_body.clone(); async move { - request_count.fetch_add(1, Ordering::SeqCst); - Json(json!({ - "id": "chatcmpl-789", - "object": "chat.completion", - "created": 1, - "model": "gpt-test", - "choices": [] - })) + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "text/event-stream") + .body(axum::body::Body::from(sse_body)) + .unwrap() } }), ); @@ -941,13 +966,62 @@ mod tests { })) .unwrap(); - let result = gateway.messages(&request, &instance).await; + let response = gateway.messages(&request, &instance).await.unwrap(); + let ChatResponse::Stream { + mut stream, + usage_rx, + } = response + else { + panic!("expected streaming response") + }; + + let message_start = stream.next().await.unwrap().unwrap(); + let block_start = stream.next().await.unwrap().unwrap(); + let block_delta = stream.next().await.unwrap().unwrap(); + let block_stop = stream.next().await.unwrap().unwrap(); + let message_delta = stream.next().await.unwrap().unwrap(); + let message_stop = stream.next().await.unwrap().unwrap(); + assert!(stream.next().await.is_none()); + assert!(matches!( - result, - Err(GatewayError::Bridge(message)) - if message.contains("hub streaming bridge") + message_start, + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } + if message.id == "chatcmpl-789" + )); + assert!(matches!( + block_start, + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. 
} + if index == 0 + )); + assert!(matches!( + block_delta, + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } + if index == 0 + && matches!(&delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") + )); + assert!(matches!( + block_stop, + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } + if index == 0 + )); + assert!(matches!( + message_delta, + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { usage, .. } + if usage.input_tokens == Some(5) + && usage.output_tokens == Some(9) + && usage.cache_creation_input_tokens == Some(0) + && usage.cache_read_input_tokens == Some(2) + )); + assert!(matches!( + message_stop, + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop )); - assert_eq!(request_count.load(Ordering::SeqCst), 0); + + let usage = usage_rx.await.unwrap(); + assert_eq!(usage.input_tokens, Some(7)); + assert_eq!(usage.output_tokens, Some(9)); + assert_eq!(usage.total_tokens, Some(16)); + assert_eq!(usage.cache_read_input_tokens, Some(2)); server.abort(); } @@ -977,7 +1051,9 @@ mod tests { "stop_sequence": null, "usage": { "input_tokens": 3, - "output_tokens": 4 + "output_tokens": 4, + "cache_creation_input_tokens": 5, + "cache_read_input_tokens": 2 } })) } @@ -1011,8 +1087,11 @@ mod tests { &response.content[0], AnthropicContentBlock::Text { text, .. 
} if text == "hello from native" )); - assert!(usage.input_tokens.is_none()); - assert!(usage.output_tokens.is_none()); + assert_eq!(usage.input_tokens, Some(10)); + assert_eq!(usage.output_tokens, Some(4)); + assert_eq!(usage.total_tokens, Some(14)); + assert_eq!(usage.cache_creation_input_tokens, Some(5)); + assert_eq!(usage.cache_read_input_tokens, Some(2)); let observed = observed.lock().await.take().unwrap(); assert_eq!(observed.0.as_deref(), Some("anthropic-secret")); @@ -1022,6 +1101,113 @@ mod tests { server.abort(); } + #[tokio::test] + async fn messages_stream_native_anthropic_reports_cache_usage() { + let sse_body = concat!( + "event: message_start\n", + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-3-5-sonnet-20241022\",\"usage\":{\"input_tokens\":3,\"output_tokens\":1,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":2}}}\n\n", + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n", + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello\"}}\n\n", + "event: content_block_stop\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "event: message_delta\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":4,\"input_tokens\":3}}\n\n", + "event: message_stop\n", + "data: {\"type\":\"message_stop\"}\n\n" + ); + let router = Router::new().route( + "/v1/messages", + post(move || async move { + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "text/event-stream") + .body(axum::body::Body::from(sse_body)) + .unwrap() + }), + ); + let (base_url, server) = spawn_server(router).await; + + let gateway = Gateway::new(ProviderRegistry::builder().build()); + let instance = 
ProviderInstance { + def: Arc::new(AnthropicDef), + auth: ProviderAuth::ApiKey("anthropic-secret".into()), + base_url_override: Some(base_url), + custom_headers: HeaderMap::new(), + }; + let request: AnthropicMessagesRequest = serde_json::from_value(json!({ + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 256, + "messages": [{"role": "user", "content": "hello"}], + "stream": true + })) + .unwrap(); + + let response = gateway.messages(&request, &instance).await.unwrap(); + let ChatResponse::Stream { + mut stream, + usage_rx, + } = response + else { + panic!("expected stream response") + }; + + let message_start = stream.next().await.unwrap().unwrap(); + let block_start = stream.next().await.unwrap().unwrap(); + let block_delta = stream.next().await.unwrap().unwrap(); + let block_stop = stream.next().await.unwrap().unwrap(); + let message_delta = stream.next().await.unwrap().unwrap(); + let message_stop = stream.next().await.unwrap().unwrap(); + assert!(stream.next().await.is_none()); + + assert!(matches!( + message_start, + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } + if message.usage.input_tokens == Some(3) + && message.usage.output_tokens == Some(1) + && message.usage.cache_creation_input_tokens == Some(5) + && message.usage.cache_read_input_tokens == Some(2) + )); + assert!(matches!( + block_start, + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. } + if index == 0 + )); + assert!(matches!( + block_delta, + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } + if index == 0 + && matches!(&delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") + )); + assert!(matches!( + block_stop, + crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } + if index == 0 + )); + assert!(matches!( + message_delta, + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { usage, .. 
} + if usage.input_tokens == Some(3) + && usage.output_tokens == Some(4) + && usage.cache_creation_input_tokens.is_none() + && usage.cache_read_input_tokens.is_none() + )); + assert!(matches!( + message_stop, + crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop + )); + + let usage = usage_rx.await.unwrap(); + assert_eq!(usage.input_tokens, Some(10)); + assert_eq!(usage.output_tokens, Some(4)); + assert_eq!(usage.total_tokens, Some(14)); + assert_eq!(usage.cache_creation_input_tokens, Some(5)); + assert_eq!(usage.cache_read_input_tokens, Some(2)); + + server.abort(); + } + #[tokio::test] async fn chat_completion_streams_hub_chunks_and_reports_usage() { let sse_body = format!( diff --git a/src/gateway/providers/anthropic/mod.rs b/src/gateway/providers/anthropic/mod.rs index a5e973a..4247b3a 100644 --- a/src/gateway/providers/anthropic/mod.rs +++ b/src/gateway/providers/anthropic/mod.rs @@ -164,7 +164,7 @@ mod tests { let events = provider .transform_anthropic_messages_stream_chunk( r#"data: {"type":"ping"}"#, - &mut AnthropicMessagesNativeStreamState, + &mut AnthropicMessagesNativeStreamState::default(), ) .unwrap(); diff --git a/src/gateway/providers/anthropic/transform.rs b/src/gateway/providers/anthropic/transform.rs index fac23c8..171995b 100644 --- a/src/gateway/providers/anthropic/transform.rs +++ b/src/gateway/providers/anthropic/transform.rs @@ -10,8 +10,8 @@ use crate::gateway::{ anthropic::{ AnthropicContent, AnthropicContentBlock, AnthropicMessage, AnthropicMessagesRequest, AnthropicMessagesResponse, AnthropicMetadata, AnthropicStreamEvent, AnthropicTool, - AnthropicToolChoice, AnthropicUsage, ContentDelta, ImageSource, SystemBlock, - SystemPrompt, + AnthropicToolChoice, AnthropicUsage, ContentDelta, DeltaUsage, ImageSource, + MessageStartUsage, SystemBlock, SystemPrompt, }, openai::{ ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, @@ -60,6 +60,7 @@ pub(crate) fn openai_to_anthropic_request( .max_tokens 
.or(request.max_completion_tokens) .unwrap_or(DEFAULT_MAX_TOKENS), + cache_control: None, system, temperature: request.temperature, top_p: request.top_p, @@ -108,7 +109,7 @@ pub(crate) fn parse_anthropic_sse_to_openai( state.response_id = Some(message.id.clone()); state.response_model = Some(message.model.clone()); state.response_created = Some(now_unix_secs()); - state.input_tokens = Some(message.usage.input_tokens); + apply_anthropic_message_start_usage_to_stream_state(state, &message.usage); Ok(vec![ChatCompletionChunk { id: message.id, @@ -132,7 +133,9 @@ pub(crate) fn parse_anthropic_sse_to_openai( index, content_block, } => match content_block { - AnthropicContentBlock::ToolUse { id, name, input } => { + AnthropicContentBlock::ToolUse { + id, name, input, .. + } => { let initial_arguments = initial_tool_arguments(&input)?; let accumulator = state .tool_call_accumulators @@ -206,14 +209,16 @@ pub(crate) fn parse_anthropic_sse_to_openai( } }, AnthropicStreamEvent::MessageDelta { delta, usage } => { - state.output_tokens = Some(usage.output_tokens); - if usage.input_tokens > 0 { - state.input_tokens = Some(usage.input_tokens); - } + apply_anthropic_delta_usage_to_stream_state(state, &usage); let usage = match (state.input_tokens, state.output_tokens) { (Some(input_tokens), Some(output_tokens)) => { - Some(stream_usage_to_openai_usage(input_tokens, output_tokens)) + Some(stream_usage_to_openai_usage_with_cached( + input_tokens, + output_tokens, + state.cache_creation_input_tokens.unwrap_or(0) + + state.cache_read_input_tokens.unwrap_or(0), + )) } _ => None, }; @@ -333,6 +338,7 @@ fn openai_assistant_message_to_anthropic( error )) })?, + cache_control: None, }); } } @@ -365,6 +371,7 @@ fn openai_tool_message_to_anthropic( None => None, }, is_error: None, + cache_control: None, }]), }) } @@ -402,6 +409,7 @@ fn content_to_anthropic_blocks( }), ContentPart::ImageUrl { image_url } => Ok(AnthropicContentBlock::Image { source: image_url_to_source(&image_url.url)?, + 
cache_control: None, }), }) .collect(), @@ -489,7 +497,7 @@ fn anthropic_blocks_to_openai_message( text_segments.push(text.clone()); rich_parts.push(ContentPart::Text { text: text.clone() }); } - AnthropicContentBlock::Image { source } => { + AnthropicContentBlock::Image { source, .. } => { has_non_text_part = true; rich_parts.push(ContentPart::ImageUrl { image_url: ImageUrl { @@ -498,7 +506,9 @@ fn anthropic_blocks_to_openai_message( }, }); } - AnthropicContentBlock::ToolUse { id, name, input } => { + AnthropicContentBlock::ToolUse { + id, name, input, .. + } => { tool_calls.push(ToolCall { id: id.clone(), r#type: "function".into(), @@ -553,16 +563,64 @@ fn anthropic_usage_to_openai_usage(usage: &AnthropicUsage) -> ChatCompletionUsag } } -fn stream_usage_to_openai_usage(input_tokens: u32, output_tokens: u32) -> ChatCompletionUsage { +fn stream_usage_to_openai_usage_with_cached( + input_tokens: u32, + output_tokens: u32, + cached_tokens: u32, +) -> ChatCompletionUsage { ChatCompletionUsage { prompt_tokens: input_tokens, completion_tokens: output_tokens, total_tokens: input_tokens + output_tokens, - prompt_tokens_details: None, + prompt_tokens_details: (cached_tokens > 0).then_some(PromptTokensDetails { + cached_tokens: Some(cached_tokens), + audio_tokens: None, + }), completion_tokens_details: None, } } +fn apply_anthropic_message_start_usage_to_stream_state( + state: &mut ChatStreamState, + usage: &MessageStartUsage, +) { + if let Some(input_tokens) = usage.input_tokens { + state.input_tokens = Some( + input_tokens + + usage.cache_creation_input_tokens.unwrap_or(0) + + usage.cache_read_input_tokens.unwrap_or(0), + ); + } + if let Some(output_tokens) = usage.output_tokens { + state.output_tokens = Some(output_tokens); + } + if let Some(cache_creation_input_tokens) = usage.cache_creation_input_tokens { + state.cache_creation_input_tokens = Some(cache_creation_input_tokens); + } + if let Some(cache_read_input_tokens) = usage.cache_read_input_tokens { + 
state.cache_read_input_tokens = Some(cache_read_input_tokens); + } +} + +fn apply_anthropic_delta_usage_to_stream_state(state: &mut ChatStreamState, usage: &DeltaUsage) { + if let Some(input_tokens) = usage.input_tokens { + state.input_tokens = Some( + input_tokens + + state.cache_creation_input_tokens.unwrap_or(0) + + state.cache_read_input_tokens.unwrap_or(0), + ); + } + if let Some(output_tokens) = usage.output_tokens { + state.output_tokens = Some(output_tokens); + } + if let Some(cache_creation_input_tokens) = usage.cache_creation_input_tokens { + state.cache_creation_input_tokens = Some(cache_creation_input_tokens); + } + if let Some(cache_read_input_tokens) = usage.cache_read_input_tokens { + state.cache_read_input_tokens = Some(cache_read_input_tokens); + } +} + fn map_anthropic_stop_reason(stop_reason: Option<&str>) -> Option { stop_reason.map(|reason| match reason { "end_turn" | "stop_sequence" => "stop".into(), @@ -797,7 +855,7 @@ mod tests { let mut state = ChatStreamState::default(); let start = parse_anthropic_sse_to_openai( - r#"data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":7}}}"#, + r#"data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","usage":{"input_tokens":7,"output_tokens":1,"cache_creation_input_tokens":3,"cache_read_input_tokens":2}}}"#, &mut state, ) .unwrap(); @@ -838,7 +896,19 @@ mod tests { finish[0].choices[0].finish_reason.as_deref(), Some("tool_calls") ); - assert_eq!(finish[0].usage.as_ref().unwrap().total_tokens, 18); + assert_eq!(finish[0].usage.as_ref().unwrap().prompt_tokens, 12); + assert_eq!(finish[0].usage.as_ref().unwrap().total_tokens, 23); + assert_eq!( + finish[0] + .usage + .as_ref() + .unwrap() + .prompt_tokens_details + .as_ref() + .unwrap() + .cached_tokens, + Some(5) + ); } #[test] diff --git a/src/gateway/streams/bridged.rs 
b/src/gateway/streams/bridged.rs index e52f38e..5a14825 100644 --- a/src/gateway/streams/bridged.rs +++ b/src/gateway/streams/bridged.rs @@ -62,6 +62,8 @@ impl BridgedStream { (Some(input_tokens), Some(output_tokens)) => Some(input_tokens + output_tokens), _ => None, }, + cache_creation_input_tokens: state.cache_creation_input_tokens, + cache_read_input_tokens: state.cache_read_input_tokens, ..Default::default() } } diff --git a/src/gateway/streams/hub.rs b/src/gateway/streams/hub.rs index 9cbdeea..d500740 100644 --- a/src/gateway/streams/hub.rs +++ b/src/gateway/streams/hub.rs @@ -78,6 +78,14 @@ impl Stream for HubChunkStream { if let Some(usage) = &chunk.usage { this.state.input_tokens = Some(usage.prompt_tokens); this.state.output_tokens = Some(usage.completion_tokens); + if this.state.cache_creation_input_tokens.is_none() + && this.state.cache_read_input_tokens.is_none() + { + this.state.cache_read_input_tokens = usage + .prompt_tokens_details + .as_ref() + .and_then(|details| details.cached_tokens); + } } } diff --git a/src/gateway/traits/chat_format.rs b/src/gateway/traits/chat_format.rs index 0d0e7cc..7d4bc0d 100644 --- a/src/gateway/traits/chat_format.rs +++ b/src/gateway/traits/chat_format.rs @@ -90,6 +90,11 @@ pub trait ChatFormat: Send + Sync + 'static { Usage::default() } + /// Extract usage from a native non-streaming response. + fn response_usage(_response: &Self::Response) -> Usage { + Usage::default() + } + /// Parse a native non-streaming response into this format. 
     fn parse_native_response(native: &NativeHandler<'_>, body: Value) -> Result<Self::Response>
     where
@@ -133,6 +138,8 @@ pub struct ChatStreamState {
     pub response_created: Option<u64>,
     pub input_tokens: Option<u32>,
     pub output_tokens: Option<u32>,
+    pub cache_creation_input_tokens: Option<u32>,
+    pub cache_read_input_tokens: Option<u32>,
 }
 
 #[cfg(test)]
diff --git a/src/gateway/traits/native.rs b/src/gateway/traits/native.rs
index bf763f2..82970a9 100644
--- a/src/gateway/traits/native.rs
+++ b/src/gateway/traits/native.rs
@@ -7,13 +7,16 @@ use crate::gateway::{
     traits::provider::ChatTransform,
     types::{
         anthropic::{AnthropicMessagesRequest, AnthropicMessagesResponse, AnthropicStreamEvent},
+        common::Usage,
         openai::responses::{ResponsesApiRequest, ResponsesApiResponse, ResponsesApiStreamEvent},
     },
 };
 
 /// Stateful data for native Anthropic Messages streaming transforms.
 #[derive(Debug, Clone, Default)]
-pub struct AnthropicMessagesNativeStreamState;
+pub struct AnthropicMessagesNativeStreamState {
+    pub usage: Usage,
+}
 
 /// Stateful data for native OpenAI Responses streaming transforms.
 #[derive(Debug, Clone, Default)]
diff --git a/src/gateway/types/anthropic.rs b/src/gateway/types/anthropic.rs
index d3c9eb2..f358d42 100644
--- a/src/gateway/types/anthropic.rs
+++ b/src/gateway/types/anthropic.rs
@@ -17,6 +17,9 @@ pub struct AnthropicMessagesRequest {
     pub messages: Vec<AnthropicMessage>,
     pub max_tokens: u32,
 
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_control: Option<CacheControl>,
+
     #[serde(skip_serializing_if = "Option::is_none")]
     pub system: Option<SystemPrompt>,
 
@@ -66,6 +69,8 @@ pub struct SystemBlock {
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct CacheControl {
     pub r#type: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ttl: Option<String>,
 }
 
 /// Anthropic request metadata.
@@ -102,13 +107,19 @@ pub enum AnthropicContentBlock {
     },
 
     #[serde(rename = "image")]
-    Image { source: ImageSource },
+    Image {
+        source: ImageSource,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
 
     #[serde(rename = "tool_use")]
     ToolUse {
         id: String,
         name: String,
         input: Value,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
     },
 
     #[serde(rename = "tool_result")]
@@ -118,6 +129,8 @@ pub enum AnthropicContentBlock {
         content: Option<AnthropicContent>,
         #[serde(skip_serializing_if = "Option::is_none")]
         is_error: Option<bool>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
     },
 }
 
@@ -168,12 +181,25 @@ pub struct AnthropicMessagesResponse {
 /// Anthropic usage metrics.
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct AnthropicUsage {
+    #[serde(default)]
     pub input_tokens: u32,
+    #[serde(default)]
     pub output_tokens: u32,
     #[serde(default)]
     pub cache_creation_input_tokens: u32,
     #[serde(default)]
     pub cache_read_input_tokens: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_creation: Option<AnthropicCacheCreation>,
+}
+
+/// Optional detailed cache creation breakdown returned by newer Claude APIs.
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+pub struct AnthropicCacheCreation {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ephemeral_5m_input_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ephemeral_1h_input_tokens: Option<u32>,
 }
 
 // ── Streaming event types ──
@@ -240,13 +266,22 @@ pub struct MessageStartPayload {
     pub r#type: String,
     pub role: String,
     pub model: String,
-    pub usage: InputUsage,
+    pub usage: MessageStartUsage,
 }
 
-/// Input usage reported at message start.
+/// Usage reported at `message_start`.
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
-pub struct InputUsage {
-    pub input_tokens: u32,
+pub struct MessageStartUsage {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_creation_input_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_read_input_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_creation: Option<AnthropicCacheCreation>,
 }
 
 /// Content delta within a `content_block_delta` event.
@@ -271,9 +306,14 @@ pub struct MessageDelta {
 /// Usage reported in `message_delta`.
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct DeltaUsage {
-    pub output_tokens: u32,
-    #[serde(default)]
-    pub input_tokens: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_creation_input_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_read_input_tokens: Option<u32>,
 }
 
 /// Anthropic API error body.
@@ -339,6 +379,23 @@ mod tests { } } + #[test] + fn request_with_top_level_cache_control_and_ttl() { + let json = json!({ + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + "messages": [{"role": "user", "content": "Hi"}] + }); + + let req: AnthropicMessagesRequest = serde_json::from_value(json).unwrap(); + assert_eq!(req.cache_control.as_ref().unwrap().r#type, "ephemeral"); + assert_eq!( + req.cache_control.as_ref().unwrap().ttl.as_deref(), + Some("1h") + ); + } + #[test] fn request_with_tools() { let json = json!({ @@ -419,7 +476,8 @@ mod tests { assert_eq!(event.event_type(), "message_start"); if let AnthropicStreamEvent::MessageStart { message } = &event { assert_eq!(message.id, "msg_123"); - assert_eq!(message.usage.input_tokens, 25); + assert_eq!(message.usage.input_tokens, Some(25)); + assert!(message.usage.output_tokens.is_none()); } else { panic!("Expected MessageStart"); } @@ -465,7 +523,8 @@ mod tests { assert_eq!(event.event_type(), "message_delta"); if let AnthropicStreamEvent::MessageDelta { delta, usage } = &event { assert_eq!(delta.stop_reason.as_deref(), Some("end_turn")); - assert_eq!(usage.output_tokens, 50); + assert_eq!(usage.output_tokens, Some(50)); + assert_eq!(usage.input_tokens, Some(0)); } else { panic!("Expected MessageDelta"); }