Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/crates/ai-adapters/src/stream/stream_handler/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ pub async fn handle_anthropic_stream(
Err(e) => {
stats.increment("error:sse_parsing");
let err_str = format!("SSE Parsing Error: {e}, data: {}", &data);
stats.log_summary("sse_parsing_error");
error!("{}", err_str);
continue;
let _ = tx_event.send(Err(anyhow!(err_str)));
return;
}
};
// Emit for Thinking and ToolUse content_block_start events.
Expand All @@ -128,8 +130,10 @@ pub async fn handle_anthropic_stream(
Err(e) => {
stats.increment("error:sse_parsing");
let err_str = format!("SSE Parsing Error: {e}, data: {}", &data);
stats.log_summary("sse_parsing_error");
error!("{}", err_str);
continue;
let _ = tx_event.send(Err(anyhow!(err_str)));
return;
}
};
match UnifiedResponse::try_from(content_block_delta) {
Expand Down
146 changes: 110 additions & 36 deletions src/crates/ai-adapters/src/stream/stream_handler/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,54 @@ fn cleanup_tool_call_tracking(
}
}

fn handle_function_call_arguments_delta(
tx_event: &mpsc::UnboundedSender<Result<UnifiedResponse>>,
stats: &mut StreamStats,
output_index: Option<usize>,
delta: Option<String>,
tool_calls_by_output_index: &mut HashMap<usize, InProgressToolCall>,
) -> Result<()> {
let Some(delta) = delta.filter(|delta| !delta.is_empty()) else {
return Ok(());
};
let Some(output_index) = output_index else {
return Err(anyhow!(
"Responses function_call_arguments.delta missing output_index"
));
};
let Some(tc) = tool_calls_by_output_index.get_mut(&output_index) else {
return Err(anyhow!(
"Responses function_call_arguments.delta for untracked output_index {}",
output_index
));
};

tc.saw_any_delta = true;
tc.args_so_far.push_str(&delta);

// Some consumers treat `id` as a "new tool call" marker and reset buffers when it repeats.
// Only send id/name once per tool call; deltas that follow carry arguments only.
let (id, name) = if tc.sent_header {
(None, None)
} else {
tc.sent_header = true;
(tc.call_id.clone(), tc.name.clone())
};

let unified_response = UnifiedResponse {
tool_call: Some(crate::stream::types::unified::UnifiedToolCall {
tool_call_index: Some(output_index),
id,
name,
arguments: Some(delta),
arguments_is_snapshot: false,
}),
..Default::default()
};
emit_unified_response(tx_event, stats, unified_response);
Ok(())
}

fn handle_function_call_output_item_done(
tx_event: &mpsc::UnboundedSender<Result<UnifiedResponse>>,
stats: &mut StreamStats,
Expand Down Expand Up @@ -272,9 +320,17 @@ pub async fn handle_responses_stream(
match event.kind.as_str() {
"response.output_item.added" => {
// Track tool calls so we can stream arguments via `response.function_call_arguments.delta`.
if let (Some(output_index), Some(item)) = (event.output_index, event.item.as_ref())
{
if let Some(item) = event.item.as_ref() {
if let Some(tc) = InProgressToolCall::from_item_value(item) {
let Some(output_index) = event.output_index else {
let error_msg =
"Responses function_call output_item.added missing output_index";
stats.increment("error:missing_output_index");
stats.log_summary("responses_tool_call_missing_output_index");
error!("{}", error_msg);
let _ = tx_event.send(Err(anyhow!(error_msg)));
return;
};
if let Some(ref call_id) = tc.call_id {
tool_call_index_by_id.insert(call_id.clone(), output_index);
}
Expand Down Expand Up @@ -302,39 +358,20 @@ pub async fn handle_responses_stream(
}
}
"response.function_call_arguments.delta" => {
let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) else {
continue;
};
let Some(output_index) = event.output_index else {
continue;
};
let Some(tc) = tool_calls_by_output_index.get_mut(&output_index) else {
continue;
};

tc.saw_any_delta = true;
tc.args_so_far.push_str(&delta);

// Some consumers treat `id` as a "new tool call" marker and reset buffers when it repeats.
// Only send id/name once per tool call; deltas that follow carry arguments only.
let (id, name) = if tc.sent_header {
(None, None)
} else {
tc.sent_header = true;
(tc.call_id.clone(), tc.name.clone())
};

let unified_response = UnifiedResponse {
tool_call: Some(crate::stream::types::unified::UnifiedToolCall {
tool_call_index: Some(output_index),
id,
name,
arguments: Some(delta),
arguments_is_snapshot: false,
}),
..Default::default()
};
emit_unified_response(&tx_event, &mut stats, unified_response);
if let Err(err) = handle_function_call_arguments_delta(
&tx_event,
&mut stats,
event.output_index,
event.delta,
&mut tool_calls_by_output_index,
) {
let error_msg = err.to_string();
stats.increment("error:function_call_arguments_delta");
stats.log_summary("responses_function_call_arguments_delta_error");
error!("{}", error_msg);
let _ = tx_event.send(Err(anyhow!(error_msg)));
return;
}
}
"response.output_item.done" => {
let Some(item_value) = event.item else {
Expand Down Expand Up @@ -540,7 +577,8 @@ pub async fn handle_responses_stream(
mod tests {
use super::{
super::stream_stats::StreamStats, extract_api_error_message,
handle_function_call_output_item_done, InProgressToolCall,
handle_function_call_arguments_delta, handle_function_call_output_item_done,
InProgressToolCall,
};
use serde_json::json;
use std::collections::HashMap;
Expand Down Expand Up @@ -622,4 +660,40 @@ mod tests {
Some("{\"city\":\"Beijing\"}")
);
}

#[test]
fn function_call_delta_requires_output_index() {
let (tx_event, _rx_event) = mpsc::unbounded_channel();
let mut tool_calls_by_output_index: HashMap<usize, InProgressToolCall> = HashMap::new();
let mut stats = StreamStats::new("Responses");

let err = handle_function_call_arguments_delta(
&tx_event,
&mut stats,
None,
Some("{\"city\"".to_string()),
&mut tool_calls_by_output_index,
)
.expect_err("missing output_index should fail");

assert!(err.to_string().contains("missing output_index"));
}

#[test]
fn function_call_delta_requires_tracked_output_item() {
let (tx_event, _rx_event) = mpsc::unbounded_channel();
let mut tool_calls_by_output_index: HashMap<usize, InProgressToolCall> = HashMap::new();
let mut stats = StreamStats::new("Responses");

let err = handle_function_call_arguments_delta(
&tx_event,
&mut stats,
Some(2),
Some("{\"city\"".to_string()),
&mut tool_calls_by_output_index,
)
.expect_err("untracked output_index should fail");

assert!(err.to_string().contains("untracked output_index 2"));
}
}
6 changes: 4 additions & 2 deletions src/crates/ai-adapters/src/stream/types/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ impl From<MessageDelta> for UnifiedResponse {

#[derive(Debug, Deserialize)]
pub struct ContentBlockStart {
pub index: Option<usize>,
pub content_block: ContentBlock,
}

Expand All @@ -118,7 +119,7 @@ impl From<ContentBlockStart> for UnifiedResponse {
match value.content_block {
ContentBlock::ToolUse { id, name } => {
let tool_call = UnifiedToolCall {
tool_call_index: None,
tool_call_index: value.index,
id: Some(id),
name: Some(name),
arguments: None,
Expand All @@ -141,6 +142,7 @@ impl From<ContentBlockStart> for UnifiedResponse {

#[derive(Debug, Deserialize)]
pub struct ContentBlockDelta {
index: Option<usize>,
delta: Delta,
}

Expand Down Expand Up @@ -172,7 +174,7 @@ impl TryFrom<ContentBlockDelta> for UnifiedResponse {
}
Delta::InputJson { partial_json } => {
let tool_call = UnifiedToolCall {
tool_call_index: None,
tool_call_index: value.index,
id: None,
name: None,
arguments: Some(partial_json),
Expand Down
Loading
Loading