diff --git a/BitFun-Installer/src-tauri/src/connection_test/client.rs b/BitFun-Installer/src-tauri/src/connection_test/client.rs
index 345035579..3be23fca3 100644
--- a/BitFun-Installer/src-tauri/src/connection_test/client.rs
+++ b/BitFun-Installer/src-tauri/src/connection_test/client.rs
@@ -2097,7 +2097,6 @@ mod tests {
             temperature: None,
             top_p: None,
             enable_thinking_process: false,
-            support_preserved_thinking: false,
             inline_think_in_text: false,
             custom_headers: None,
             custom_headers_mode: None,
@@ -2121,7 +2120,6 @@ mod tests {
             temperature: None,
             top_p: None,
             enable_thinking_process: false,
-            support_preserved_thinking: false,
             inline_think_in_text: false,
             custom_headers: None,
             custom_headers_mode: None,
@@ -2150,7 +2148,6 @@ mod tests {
             temperature: None,
             top_p: None,
             enable_thinking_process: false,
-            support_preserved_thinking: false,
             inline_think_in_text: false,
             custom_headers: None,
             custom_headers_mode: None,
@@ -2180,7 +2177,6 @@ mod tests {
             temperature: Some(0.2),
             top_p: Some(0.8),
             enable_thinking_process: true,
-            support_preserved_thinking: true,
             inline_think_in_text: false,
             custom_headers: None,
             custom_headers_mode: None,
@@ -2259,7 +2255,6 @@ mod tests {
             temperature: None,
             top_p: None,
             enable_thinking_process: false,
-            support_preserved_thinking: true,
             inline_think_in_text: false,
             custom_headers: None,
             custom_headers_mode: None,
diff --git a/BitFun-Installer/src-tauri/src/connection_test/types/config.rs b/BitFun-Installer/src-tauri/src/connection_test/types/config.rs
index ab45cbcb3..87ec6cbe9 100644
--- a/BitFun-Installer/src-tauri/src/connection_test/types/config.rs
+++ b/BitFun-Installer/src-tauri/src/connection_test/types/config.rs
@@ -82,7 +82,6 @@ pub struct AIConfig {
     pub temperature: Option<f32>,
     pub top_p: Option<f32>,
     pub enable_thinking_process: bool,
-    pub support_preserved_thinking: bool,
     pub inline_think_in_text: bool,
     pub custom_headers: Option<HashMap<String, String>>,
     /// "replace" (default) or "merge" (defaults first, then custom)
diff --git a/BitFun-Installer/src-tauri/src/installer/ai_config.rs b/BitFun-Installer/src-tauri/src/installer/ai_config.rs
index 38c6719ba..63730fad1 100644
--- a/BitFun-Installer/src-tauri/src/installer/ai_config.rs
+++ b/BitFun-Installer/src-tauri/src/installer/ai_config.rs
@@ -48,7 +48,6 @@ pub fn ai_config_from_installer_model(m: &ModelConfig) -> Result Result<(), String> {
     model_map.insert("recommended_for".to_string(), Value::Array(Vec::new()));
     model_map.insert("metadata".to_string(), Value::Null);
     model_map.insert("enable_thinking_process".to_string(), Value::Bool(false));
-    model_map.insert("support_preserved_thinking".to_string(), Value::Bool(false));
     model_map.insert("inline_think_in_text".to_string(), Value::Bool(false));
 
     if let Some(skip_ssl_verify) = model.skip_ssl_verify {
diff --git a/src/crates/core/build.rs b/src/crates/core/build.rs
index 6c8b9bcbc..d75f42959 100644
--- a/src/crates/core/build.rs
+++ b/src/crates/core/build.rs
@@ -248,10 +248,7 @@ fn embed_announcement_content() -> Result<(), Box<dyn std::error::Error>> {
         .join("announcement")
         .join("content");
 
-    println!(
-        "cargo:rerun-if-changed={}",
-        content_root.display()
-    );
+    println!("cargo:rerun-if-changed={}", content_root.display());
     emit_rerun_if_changed(&content_root);
 
     let mut entries: HashMap<String, String> = HashMap::new();
@@ -282,7 +279,10 @@ fn embed_announcement_content() -> Result<(), Box<dyn std::error::Error>> {
             .to_string_lossy()
            .to_string();
         // Strip leading numeric prefix (e.g. "001_") to use bare id as key.
-        let key_stem = if stem.len() > 4 && stem.chars().take(3).all(|c| c.is_ascii_digit()) && stem.chars().nth(3) == Some('_') {
+        let key_stem = if stem.len() > 4
+            && stem.chars().take(3).all(|c| c.is_ascii_digit())
+            && stem.chars().nth(3) == Some('_')
+        {
             stem[4..].to_string()
         } else {
             stem
@@ -308,7 +308,10 @@ fn generate_embedded_announcements_code(
     let dest_path = Path::new(&out_dir).join("embedded_announcements.rs");
     let mut file = fs::File::create(&dest_path)?;
 
-    writeln!(file, "// Embedded announcement content (auto-generated by build.rs)")?;
+    writeln!(
+        file,
+        "// Embedded announcement content (auto-generated by build.rs)"
+    )?;
     writeln!(file, "// Do not edit manually.")?;
     writeln!(file)?;
     writeln!(file, "use std::collections::HashMap;")?;
diff --git a/src/crates/core/src/agentic/agents/prompt_builder/prompt_builder_impl.rs b/src/crates/core/src/agentic/agents/prompt_builder/prompt_builder_impl.rs
index 82410bc0e..5b92cff84 100644
--- a/src/crates/core/src/agentic/agents/prompt_builder/prompt_builder_impl.rs
+++ b/src/crates/core/src/agentic/agents/prompt_builder/prompt_builder_impl.rs
@@ -5,8 +5,8 @@ use crate::service::ai_memory::AIMemoryManager;
 use crate::service::ai_rules::get_global_ai_rules_service;
 use crate::service::bootstrap::build_workspace_persona_prompt;
 use crate::service::config::get_app_language_code;
-use crate::service::filesystem::get_formatted_directory_listing;
 use crate::service::config::global::GlobalConfigManager;
+use crate::service::filesystem::get_formatted_directory_listing;
 use crate::service::project_context::ProjectContextService;
 use crate::util::errors::{BitFunError, BitFunResult};
 use log::{debug, warn};
diff --git a/src/crates/core/src/agentic/coordination/coordinator.rs b/src/crates/core/src/agentic/coordination/coordinator.rs
index 386d202ce..3e1d90dce 100644
--- a/src/crates/core/src/agentic/coordination/coordinator.rs
+++ b/src/crates/core/src/agentic/coordination/coordinator.rs
@@ -988,22 +988,17 @@ Update the persona files and delete BOOTSTRAP.md as soon as bootstrap is complet
         let session_max_tokens = session.config.max_context_tokens;
 
         // Unify context_window: min(model capability, session config)
-        let model_context_window = match crate::infrastructure::ai::get_global_ai_client_factory()
-            .await
-        {
-            Ok(factory) => {
-                let model_id = session
-                    .config
-                    .model_id
-                    .as_deref()
-                    .unwrap_or("default");
-                match factory.get_client_resolved(model_id).await {
-                    Ok(client) => Some(client.config.context_window as usize),
-                    Err(_) => None,
+        let model_context_window =
+            match crate::infrastructure::ai::get_global_ai_client_factory().await {
+                Ok(factory) => {
+                    let model_id = session.config.model_id.as_deref().unwrap_or("default");
+                    match factory.get_client_resolved(model_id).await {
+                        Ok(client) => Some(client.config.context_window as usize),
+                        Err(_) => None,
+                    }
                 }
-            }
-            Err(_) => None,
-        };
+                Err(_) => None,
+            };
         let context_window = match model_context_window {
             Some(mcw) => mcw.min(session_max_tokens),
             None => session_max_tokens,
diff --git a/src/crates/core/src/agentic/core/message.rs b/src/crates/core/src/agentic/core/message.rs
index 955f5a1ea..5655f1d35 100644
--- a/src/crates/core/src/agentic/core/message.rs
+++ b/src/crates/core/src/agentic/core/message.rs
@@ -56,8 +56,6 @@ pub struct MessageMetadata {
     pub turn_id: Option<String>,
     pub round_id: Option<String>,
     pub tokens: Option<usize>,
-    #[serde(skip)] // Not serialized, auxiliary field for runtime use only
-    pub keep_thinking: bool,
     /// Anthropic extended thinking signature (for passing back in multi-turn conversations)
     #[serde(skip_serializing_if = "Option::is_none")]
     pub thinking_signature: Option<String>,
@@ -168,7 +166,6 @@ impl From<Message> for AIMessage {
             MessageRole::Tool => "tool",
             MessageRole::System => "system",
         };
-        let keep_thinking = msg.metadata.keep_thinking;
         let thinking_signature = msg.metadata.thinking_signature.clone();
 
         match msg.content {
@@ -273,16 +270,10 @@ impl From<Message> for AIMessage {
                 };
 
                 // Reasoning content (interleaved thinking mode)
-                let reasoning = if keep_thinking {
-                    reasoning_content.filter(|r| !r.is_empty())
-                } else {
-                    None
-                };
-
                 Self {
                     role: "assistant".to_string(),
                     content,
-                    reasoning_content: reasoning,
+                    reasoning_content: reasoning_content.filter(|r| !r.is_empty()),
                     thinking_signature: thinking_signature.clone(),
                     tool_calls: converted_tool_calls,
                     tool_call_id: None,
@@ -506,7 +497,7 @@ impl Message {
         50 + tiles * 200
     }
 
-    pub fn estimate_tokens(&self) -> usize {
+    pub fn estimate_tokens_with_reasoning(&self, include_reasoning: bool) -> usize {
         let mut total = 0usize;
 
         total += 4;
@@ -525,7 +516,7 @@ impl Message {
                 text,
                 tool_calls,
             } => {
-                if self.metadata.keep_thinking {
+                if include_reasoning {
                     if let Some(reasoning) = reasoning_content.as_ref() {
                         total += TokenCounter::estimate_tokens(reasoning);
                     }
@@ -564,6 +555,10 @@ impl Message {
 
         total
     }
+
+    fn estimate_tokens(&self) -> usize {
+        self.estimate_tokens_with_reasoning(true)
+    }
 }
 
 impl Display for MessageContent {
diff --git a/src/crates/core/src/agentic/core/messages_helper.rs b/src/crates/core/src/agentic/core/messages_helper.rs
index 281af38ed..371101da4 100644
--- a/src/crates/core/src/agentic/core/messages_helper.rs
+++ b/src/crates/core/src/agentic/core/messages_helper.rs
@@ -1,75 +1,71 @@
 use super::{CompressedTodoItem, CompressedTodoSnapshot, Message, MessageContent, MessageRole};
+use crate::util::token_counter::TokenCounter;
 use crate::util::types::Message as AIMessage;
-use log::warn;
+use crate::util::types::ToolDefinition;
 
 pub struct MessageHelper;
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RequestReasoningTokenPolicy {
+    FullHistory,
+    LatestTurnOnly,
+    SkipAll,
+}
+
 impl MessageHelper {
-    pub fn compute_keep_thinking_flags(
-        messages: &mut [Message],
-        enable_thinking: bool,
-        support_preserved_thinking: bool,
-    ) {
+    pub fn convert_messages(messages: &[Message]) -> Vec<AIMessage> {
+        messages.iter().map(AIMessage::from).collect()
+    }
+
+    pub fn estimate_request_tokens(
+        messages: &[Message],
+        tools: Option<&[ToolDefinition]>,
+        reasoning_policy: RequestReasoningTokenPolicy,
+    ) -> usize {
+        let reasoning_frontier_start = match reasoning_policy {
+            RequestReasoningTokenPolicy::FullHistory => Some(0),
+            RequestReasoningTokenPolicy::LatestTurnOnly => {
+                Some(Self::find_reasoning_frontier_start(messages))
+            }
+            RequestReasoningTokenPolicy::SkipAll => None,
+        };
+
+        let mut total = messages
+            .iter()
+            .enumerate()
+            .map(|(index, message)| {
+                let include_reasoning =
+                    reasoning_frontier_start.is_some_and(|frontier_start| index >= frontier_start);
+                message.estimate_tokens_with_reasoning(include_reasoning)
+            })
+            .sum::<usize>();
+
+        total += 3;
+
+        if let Some(tool_defs) = tools {
+            total += TokenCounter::estimate_tool_definitions_tokens(tool_defs);
+        }
+
+        total
+    }
+
+    fn find_reasoning_frontier_start(messages: &[Message]) -> usize {
         if messages.is_empty() {
-            return;
+            return 0;
         }
-        if !enable_thinking {
-            messages.iter_mut().for_each(|m| {
-                if m.metadata.keep_thinking {
-                    m.metadata.keep_thinking = false;
-                    m.metadata.tokens = None;
-                }
-            });
-        } else if support_preserved_thinking {
-            messages.iter_mut().for_each(|m| {
-                if !m.metadata.keep_thinking {
-                    m.metadata.keep_thinking = true;
-                    m.metadata.tokens = None;
-                }
-            });
-        } else {
-            let last_message_turn_id = messages.last().and_then(|m| m.metadata.turn_id.clone());
-            if let Some(last_turn_id) = last_message_turn_id {
-                messages.iter_mut().for_each(|m| {
-                    let keep_thinking = m
-                        .metadata
-                        .turn_id
-                        .as_ref()
-                        .is_some_and(|cur_turn_id| cur_turn_id == &last_turn_id);
-                    if m.metadata.keep_thinking != keep_thinking {
-                        m.metadata.keep_thinking = keep_thinking;
-                        m.metadata.tokens = None;
-                    }
-                })
-            } else {
-                // Find the last actual user-turn boundary from back to front.
-                let last_user_message_index =
-                    messages.iter().rposition(|m| m.is_actual_user_message());
-                if let Some(last_user_message_index) = last_user_message_index {
-                    // Messages from the last user message onwards are messages for this turn
-                    messages.iter_mut().enumerate().for_each(|(index, m)| {
-                        let keep_thinking = index >= last_user_message_index;
-                        if m.metadata.keep_thinking != keep_thinking {
-                            m.metadata.keep_thinking = keep_thinking;
-                            m.metadata.tokens = None;
-                        }
-                    })
-                } else {
-                    // No user message found, should not reach here in practice
-                    warn!("compute_keep_thinking_flags: no user message found");
-
-                    messages.iter_mut().for_each(|m| {
-                        if m.metadata.keep_thinking {
-                            m.metadata.keep_thinking = false;
-                            m.metadata.tokens = None;
-                        }
-                    });
-                }
+
+        if let Some(last_turn_id) = messages.last().and_then(|m| m.metadata.turn_id.as_deref()) {
+            if let Some(frontier_start) = messages
+                .iter()
+                .position(|m| m.metadata.turn_id.as_deref() == Some(last_turn_id))
+            {
+                return frontier_start;
             }
         }
-    }
 
-    pub fn convert_messages(messages: &[Message]) -> Vec<AIMessage> {
-        messages.iter().map(AIMessage::from).collect()
+        messages
+            .iter()
+            .rposition(Message::is_actual_user_message)
+            .unwrap_or(messages.len().saturating_sub(1))
     }
 
     pub fn group_messages_by_turns(mut messages: Vec<Message>) -> Vec<Vec<Message>> {
@@ -194,3 +190,89 @@ impl MessageHelper {
         None
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::{MessageHelper, RequestReasoningTokenPolicy};
+    use crate::agentic::core::Message;
+    use crate::util::token_counter::TokenCounter;
+
+    #[test]
+    fn latest_turn_reasoning_policy_uses_turn_id_boundary() {
+        let messages = vec![
+            Message::user("old user".to_string()).with_turn_id("turn-1".to_string()),
+            Message::assistant_with_reasoning(
+                Some("old reasoning".to_string()),
+                "old answer".to_string(),
+                Vec::new(),
+            )
+            .with_turn_id("turn-1".to_string()),
+            Message::user("new user".to_string()).with_turn_id("turn-2".to_string()),
+            Message::assistant_with_reasoning(
+                Some("new reasoning".to_string()),
+                "new answer".to_string(),
+                Vec::new(),
+            )
+            .with_turn_id("turn-2".to_string()),
+        ];
+
+        let full = MessageHelper::estimate_request_tokens(
+            &messages,
+            None,
+            RequestReasoningTokenPolicy::FullHistory,
+        );
+        let latest = MessageHelper::estimate_request_tokens(
+            &messages,
+            None,
+            RequestReasoningTokenPolicy::LatestTurnOnly,
+        );
+        let skip_all = MessageHelper::estimate_request_tokens(
+            &messages,
+            None,
+            RequestReasoningTokenPolicy::SkipAll,
+        );
+
+        assert_eq!(
+            full - latest,
+            TokenCounter::estimate_tokens("old reasoning")
+        );
+        assert_eq!(
+            latest - skip_all,
+            TokenCounter::estimate_tokens("new reasoning")
+        );
+    }
+
+    #[test]
+    fn latest_turn_reasoning_policy_falls_back_to_last_actual_user_message() {
+        let messages = vec![
+            Message::user("old user".to_string()),
+            Message::assistant_with_reasoning(
+                Some("old reasoning".to_string()),
+                "old answer".to_string(),
answer".to_string(), + Vec::new(), + ), + Message::user("new user".to_string()), + Message::assistant_with_reasoning( + Some("new reasoning".to_string()), + "new answer".to_string(), + Vec::new(), + ), + ]; + + let latest = MessageHelper::estimate_request_tokens( + &messages, + None, + RequestReasoningTokenPolicy::LatestTurnOnly, + ); + let skip_all = MessageHelper::estimate_request_tokens( + &messages, + None, + RequestReasoningTokenPolicy::SkipAll, + ); + + assert_eq!( + latest - skip_all, + TokenCounter::estimate_tokens("new reasoning") + ); + } +} diff --git a/src/crates/core/src/agentic/core/mod.rs b/src/crates/core/src/agentic/core/mod.rs index ad8b6a552..7d46528b5 100644 --- a/src/crates/core/src/agentic/core/mod.rs +++ b/src/crates/core/src/agentic/core/mod.rs @@ -15,7 +15,7 @@ pub use message::{ CompressedToolCall, CompressionEntry, CompressionPayload, Message, MessageContent, MessageRole, MessageSemanticKind, ToolCall, ToolResult, }; -pub use messages_helper::MessageHelper; +pub use messages_helper::{MessageHelper, RequestReasoningTokenPolicy}; pub use model_round::ModelRound; pub use prompt_markup::{ has_prompt_markup, is_system_reminder_only, render_system_reminder, render_user_query, diff --git a/src/crates/core/src/agentic/execution/execution_engine.rs b/src/crates/core/src/agentic/execution/execution_engine.rs index 35e463115..68efe2921 100644 --- a/src/crates/core/src/agentic/execution/execution_engine.rs +++ b/src/crates/core/src/agentic/execution/execution_engine.rs @@ -5,7 +5,10 @@ use super::round_executor::RoundExecutor; use super::types::{ExecutionContext, ExecutionResult, RoundContext}; use crate::agentic::agents::{get_agent_registry, PromptBuilderContext, RemoteExecutionHints}; -use crate::agentic::core::{Message, MessageContent, MessageHelper, MessageSemanticKind, Session}; +use crate::agentic::core::{ + Message, MessageContent, MessageHelper, MessageSemanticKind, RequestReasoningTokenPolicy, + Session, +}; use crate::agentic::events::{AgenticEvent, EventPriority, EventQueue}; use crate::agentic::image_analysis::{ build_multimodal_message_with_images, process_image_contexts_for_provider, ImageContextData, @@ -81,17 +84,14 @@ impl ExecutionEngine { } fn estimate_request_tokens_internal( - messages: &mut [Message], + messages: &[Message], tools: Option<&[ToolDefinition]>, ) -> usize { - let mut total: usize = messages.iter_mut().map(|m| m.get_tokens()).sum(); - total += 3; - - if let Some(tool_defs) = tools { - total += TokenCounter::estimate_tool_definitions_tokens(tool_defs); - } - - total + MessageHelper::estimate_request_tokens( + messages, + tools, + RequestReasoningTokenPolicy::LatestTurnOnly, + ) } /// Emergency truncation: drop oldest API rounds (assistant+tool pairs) @@ -152,7 +152,10 @@ impl ExecutionEngine { let tool_tokens = tools .map(TokenCounter::estimate_tool_definitions_tokens) .unwrap_or(0); - let preserved_tokens: usize = preserved.iter().map(|m| m.estimate_tokens()).sum::() + let preserved_tokens: usize = preserved + .iter() + .map(|m| m.estimate_tokens_with_reasoning(true)) + .sum::() + tool_tokens + 3; @@ -161,11 +164,14 @@ impl ExecutionEngine { + rounds .iter() .flat_map(|r| r.iter()) - .map(|m| m.estimate_tokens()) + .map(|m| m.estimate_tokens_with_reasoning(true)) .sum::(); while total_tokens > context_window && kept_start < rounds.len() { - let round_tokens: usize = rounds[kept_start].iter().map(|m| m.estimate_tokens()).sum(); + let round_tokens: usize = rounds[kept_start] + .iter() + .map(|m| m.estimate_tokens_with_reasoning(true)) + 
+                .sum();
             total_tokens -= round_tokens;
             kept_start += 1;
         }
@@ -1003,9 +1009,6 @@ impl ExecutionEngine {
             }
         };
 
-        // Get configuration for whether to support preserving historical thinking content
-        let enable_thinking = ai_client.config.enable_thinking_process;
-        let support_preserved_thinking = ai_client.config.support_preserved_thinking;
         let model_context_window = ai_client.config.context_window as usize;
         let session_max_tokens = session.config.max_context_tokens;
         let context_window = model_context_window.min(session_max_tokens);
@@ -1228,15 +1231,9 @@ impl ExecutionEngine {
                 break;
             }
 
-            MessageHelper::compute_keep_thinking_flags(
-                &mut messages,
-                enable_thinking,
-                support_preserved_thinking,
-            );
-
             // Check and compress before sending AI request
             let mut current_tokens =
-                Self::estimate_request_tokens_internal(&mut messages, tool_definitions.as_deref());
+                Self::estimate_request_tokens_internal(&messages, tool_definitions.as_deref());
             debug!(
                 "Round {} token usage before send: {} / {} tokens ({:.1}%)",
                 round_index,
@@ -1249,7 +1246,8 @@ impl ExecutionEngine {
             // considering full compression. This is a cheap, local-only
             // operation that can free significant tokens.
             let token_usage_ratio = current_tokens as f32 / context_window as f32;
-            if enable_context_compression && token_usage_ratio >= microcompact_config.trigger_ratio {
+            if enable_context_compression && token_usage_ratio >= microcompact_config.trigger_ratio
+            {
                 if let Some(mc_result) =
                     crate::agentic::session::compression::microcompact::microcompact_messages(
                         &mut messages,
@@ -1346,10 +1344,8 @@ impl ExecutionEngine {
 
             // L2: Emergency truncation — if tokens still exceed context_window
             // after all compression layers, drop oldest API rounds until we fit.
-            let post_compress_tokens = Self::estimate_request_tokens_internal(
-                &mut messages,
-                tool_definitions.as_deref(),
-            );
+            let post_compress_tokens =
+                Self::estimate_request_tokens_internal(&messages, tool_definitions.as_deref());
             if post_compress_tokens > context_window {
                 warn!(
                     "Round {} tokens ({}) still exceed context_window ({}) after compression, performing emergency truncation",
                 Self::emergency_truncate_messages(
                     &mut messages,
                     context_window,
                     tool_definitions.as_deref(),
                 );
-                let after_truncate = Self::estimate_request_tokens_internal(
-                    &mut messages,
-                    tool_definitions.as_deref(),
-                );
+                let after_truncate =
+                    Self::estimate_request_tokens_internal(&messages, tool_definitions.as_deref());
                 info!(
                     "Emergency truncation complete: tokens {} -> {}",
                     post_compress_tokens, after_truncate
diff --git a/src/crates/core/src/agentic/execution/round_executor.rs b/src/crates/core/src/agentic/execution/round_executor.rs
index 7d492ce14..af8f223c7 100644
--- a/src/crates/core/src/agentic/execution/round_executor.rs
+++ b/src/crates/core/src/agentic/execution/round_executor.rs
@@ -432,27 +432,22 @@ impl RoundExecutor {
                 stream_result
                     .tool_calls
                     .iter()
-                    .map(|tc| {
-                        crate::agentic::tools::pipeline::ToolExecutionResult {
+                    .map(|tc| crate::agentic::tools::pipeline::ToolExecutionResult {
+                        tool_id: tc.tool_id.clone(),
+                        tool_name: tc.tool_name.clone(),
+                        result: crate::agentic::core::ToolResult {
                             tool_id: tc.tool_id.clone(),
                             tool_name: tc.tool_name.clone(),
-                            result: crate::agentic::core::ToolResult {
-                                tool_id: tc.tool_id.clone(),
-                                tool_name: tc.tool_name.clone(),
-                                result: serde_json::json!({
-                                    "error": e.to_string(),
-                                    "message": format!("Tool pipeline execution failed: {}", e)
-                                }),
-                                result_for_assistant: Some(format!(
-                                    "Tool execution failed: {}",
-                                    e
-                                )),
-                                is_error: true,
-                                duration_ms: None,
-                                image_attachments: None,
-                            },
-                            execution_time_ms: 0,
-                        }
+                            result: serde_json::json!({
+                                "error": e.to_string(),
+                                "message": format!("Tool pipeline execution failed: {}", e)
+                            }),
+                            result_for_assistant: Some(format!("Tool execution failed: {}", e)),
+                            is_error: true,
+                            duration_ms: None,
+                            image_attachments: None,
+                        },
+                        execution_time_ms: 0,
                     })
                     .collect()
             }
diff --git a/src/crates/core/src/agentic/session/compression/microcompact.rs b/src/crates/core/src/agentic/session/compression/microcompact.rs
index 857f53021..c751b51e3 100644
--- a/src/crates/core/src/agentic/session/compression/microcompact.rs
+++ b/src/crates/core/src/agentic/session/compression/microcompact.rs
@@ -220,10 +220,7 @@ mod tests {
 
     #[test]
     fn no_op_when_within_keep_window() {
-        let mut messages = vec![
-            make_tool_result("Read", "a"),
-            make_tool_result("Grep", "b"),
-        ];
+        let mut messages = vec![make_tool_result("Read", "a"), make_tool_result("Grep", "b")];
 
         let config = MicrocompactConfig {
             keep_recent: 5,
diff --git a/src/crates/core/src/agentic/session/session_manager.rs b/src/crates/core/src/agentic/session/session_manager.rs
index 5f93472da..8521fc939 100644
--- a/src/crates/core/src/agentic/session/session_manager.rs
+++ b/src/crates/core/src/agentic/session/session_manager.rs
@@ -668,7 +668,9 @@ impl SessionManager {
                 "Session evicted from memory, restoring for model update: session_id={}",
                 session_id
             );
-            let _ = self.restore_session(&workspace_path.clone(), session_id).await;
+            let _ = self
+                .restore_session(&workspace_path.clone(), session_id)
+                .await;
         }
     }
diff --git a/src/crates/core/src/agentic/tools/implementations/skills/builtin.rs b/src/crates/core/src/agentic/tools/implementations/skills/builtin.rs
index d4b2ce46d..9996e7ecb 100644
--- a/src/crates/core/src/agentic/tools/implementations/skills/builtin.rs
+++ b/src/crates/core/src/agentic/tools/implementations/skills/builtin.rs
@@ -173,7 +173,10 @@ mod tests {
         assert_eq!(builtin_skill_group_key("xlsx"), Some("office"));
         assert_eq!(builtin_skill_group_key("find-skills"), Some("meta"));
         assert_eq!(builtin_skill_group_key("writing-skills"), Some("meta"));
-        assert_eq!(builtin_skill_group_key("agent-browser"), Some("computer-use"));
+        assert_eq!(
+            builtin_skill_group_key("agent-browser"),
+            Some("computer-use")
+        );
         assert_eq!(
             builtin_skill_group_key("test-driven-development"),
             Some("superpowers")
diff --git a/src/crates/core/src/agentic/tools/pipeline/tool_pipeline.rs b/src/crates/core/src/agentic/tools/pipeline/tool_pipeline.rs
index 40ff96722..954a01500 100644
--- a/src/crates/core/src/agentic/tools/pipeline/tool_pipeline.rs
+++ b/src/crates/core/src/agentic/tools/pipeline/tool_pipeline.rs
@@ -362,7 +362,10 @@ impl ToolPipeline {
             let task_id = &task_ids[idx];
 
             let (tool_id, tool_name) = if let Some(task) = self.state_manager.get_task(task_id) {
-                (task.tool_call.tool_id.clone(), task.tool_call.tool_name.clone())
+                (
+                    task.tool_call.tool_id.clone(),
+                    task.tool_call.tool_name.clone(),
+                )
             } else {
                 warn!("Task not found in state manager: {}", task_id);
                 (task_id.clone(), "unknown".to_string())
@@ -407,7 +410,10 @@ impl ToolPipeline {
                 let (tool_id, tool_name) =
                     if let Some(task) = self.state_manager.get_task(&task_id) {
-                        (task.tool_call.tool_id.clone(), task.tool_call.tool_name.clone())
+                        (
+                            task.tool_call.tool_id.clone(),
+                            task.tool_call.tool_name.clone(),
+                        )
                     } else {
                         warn!("Task not found in state manager: {}", task_id);
                         (task_id.clone(), "unknown".to_string())
diff --git a/src/crates/core/src/agentic/tools/workspace_paths.rs b/src/crates/core/src/agentic/tools/workspace_paths.rs
index 63fc93a74..017328fe1 100644
--- a/src/crates/core/src/agentic/tools/workspace_paths.rs
+++ b/src/crates/core/src/agentic/tools/workspace_paths.rs
@@ -133,7 +133,10 @@ mod tests {
         let resolved = resolve_path_with_workspace("src/main.rs", Some(Path::new("/repo")))
             .expect("path should resolve");
 
-        assert_eq!(PathBuf::from(resolved), Path::new("/repo").join("src/main.rs"));
+        assert_eq!(
+            PathBuf::from(resolved),
+            Path::new("/repo").join("src/main.rs")
+        );
     }
 
     #[test]
diff --git a/src/crates/core/src/infrastructure/ai/client.rs b/src/crates/core/src/infrastructure/ai/client.rs
index b36c87810..a9deb2422 100644
--- a/src/crates/core/src/infrastructure/ai/client.rs
+++ b/src/crates/core/src/infrastructure/ai/client.rs
@@ -1,1778 +1,92 @@
-//! AI client implementation - refactored version
+//! AI client implementation.
 //!
-//! Uses a modular architecture to separate provider-specific logic into the providers module
-
-use crate::infrastructure::ai::providers::anthropic::AnthropicMessageConverter;
-use crate::infrastructure::ai::providers::gemini::GeminiMessageConverter;
-use crate::infrastructure::ai::providers::openai::OpenAIMessageConverter;
-use crate::infrastructure::ai::tool_call_accumulator::{PendingToolCall, ToolCallBoundary};
+//! The client module now acts as a small facade:
+//! - `client/*` holds shared transport and aggregation utilities
+//! - `providers/*` owns provider-specific request/response adaptation
+
+pub(crate) mod format;
+pub(crate) mod healthcheck;
+pub(crate) mod http;
+pub(crate) mod quirks;
+pub(crate) mod response_aggregator;
+pub(crate) mod sse;
+pub(crate) mod utils;
+
+use crate::infrastructure::ai::providers::{anthropic, gemini, openai};
 use crate::service::config::ProxyConfig;
 use crate::util::types::*;
-use ai_stream_handlers::{
-    handle_anthropic_stream, handle_gemini_stream, handle_openai_stream, handle_responses_stream,
-    UnifiedResponse,
-};
-use anyhow::{anyhow, Result};
-use futures::StreamExt;
-use log::{debug, error, info, warn};
-use reqwest::{Client, Proxy};
-use serde::Deserialize;
+use anyhow::Result;
+use format::ApiFormat;
+use reqwest::Client;
 use tokio::sync::mpsc;
 
-/// Streamed response result with the parsed stream and optional raw SSE receiver
+/// Streamed response result with the parsed stream and optional raw SSE receiver.
 pub struct StreamResponse {
-    /// Parsed response stream
-    pub stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<UnifiedResponse>> + Send>>,
-    /// Raw SSE receiver (for error diagnostics)
+    pub stream: std::pin::Pin<
+        Box<dyn futures::Stream<Item = Result<UnifiedResponse>> + Send>,
+    >,
     pub raw_sse_rx: Option<mpsc::UnboundedReceiver<String>>,
 }
 
 #[derive(Debug, Clone)]
-pub struct AIClient {
-    client: Client,
-    pub config: AIConfig,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIModelsResponse {
-    data: Vec<OpenAIModelEntry>,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIModelEntry {
-    id: String,
-}
-
-#[derive(Debug, Deserialize)]
-struct AnthropicModelsResponse {
-    data: Vec<AnthropicModelEntry>,
-}
-
-#[derive(Debug, Deserialize)]
-struct AnthropicModelEntry {
-    id: String,
-    #[serde(default)]
-    display_name: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct GeminiModelsResponse {
-    #[serde(default)]
-    models: Vec<GeminiModelEntry>,
-}
-
-#[derive(Debug, Deserialize)]
-#[serde(rename_all = "camelCase")]
-struct GeminiModelEntry {
-    name: String,
-    #[serde(default)]
-    display_name: Option<String>,
-    #[serde(default, deserialize_with = "deserialize_null_as_default")]
-    supported_generation_methods: Vec<String>,
-}
-
-fn deserialize_null_as_default<'de, D, T>(deserializer: D) -> std::result::Result<T, D::Error>
-where
-    D: serde::Deserializer<'de>,
-    T: Default + serde::Deserialize<'de>,
-{
-    Option::<T>::deserialize(deserializer).map(|v| v.unwrap_or_default())
-}
-
-impl AIClient {
-    const TEST_IMAGE_EXPECTED_CODE: &'static str = "BYGR";
-    const TEST_IMAGE_PNG_BASE64: &'static str =
-        "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAACBklEQVR42u3ZsREAIAwDMYf9dw4txwJupI7Wua+YZEPBfO91h4ZjAgQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABIAAQAAgABAACAAEAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAAAAAAAEDRZI3QGf7jDvEPAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAACAABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAAAjABAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQALwuLkoG8OSfau4AAAAASUVORK5CYII=";
-    const STREAM_CONNECT_TIMEOUT_SECS: u64 = 10;
-    const HTTP_POOL_IDLE_TIMEOUT_SECS: u64 = 30;
-    const HTTP_TCP_KEEPALIVE_SECS: u64 = 60;
-
-    fn image_test_response_matches_expected(response: &str) -> bool {
-        let upper = response.to_ascii_uppercase();
-
-        // Accept contiguous letters even when separated by spaces/punctuation.
-        let letters_only: String = upper.chars().filter(|c| c.is_ascii_alphabetic()).collect();
-        if letters_only.contains(Self::TEST_IMAGE_EXPECTED_CODE) {
-            return true;
-        }
-
-        let tokens: Vec<&str> = upper
-            .split(|c: char| !c.is_ascii_alphabetic())
-            .filter(|s| !s.is_empty())
-            .collect();
-
-        if tokens.contains(&Self::TEST_IMAGE_EXPECTED_CODE) {
-            return true;
-        }
-
-        // Accept outputs like: "B Y G R".
-        let single_letter_stream: String = tokens
-            .iter()
-            .filter_map(|token| {
-                if token.len() == 1 {
-                    let ch = token.chars().next()?;
-                    if matches!(ch, 'R' | 'G' | 'B' | 'Y') {
-                        return Some(ch);
-                    }
-                }
-                None
-            })
-            .collect();
-        if single_letter_stream.contains(Self::TEST_IMAGE_EXPECTED_CODE) {
-            return true;
-        }
-
-        // Accept outputs like: "Blue, Yellow, Green, Red".
-        let color_word_stream: String = tokens
-            .iter()
-            .filter_map(|token| match *token {
-                "RED" => Some('R'),
-                "GREEN" => Some('G'),
-                "BLUE" => Some('B'),
-                "YELLOW" => Some('Y'),
-                _ => None,
-            })
-            .collect();
-        if color_word_stream.contains(Self::TEST_IMAGE_EXPECTED_CODE) {
-            return true;
-        }
-
-        // Last fallback: keep only RGBY letters and search code.
-        let color_letter_stream: String = upper
-            .chars()
-            .filter(|c| matches!(*c, 'R' | 'G' | 'B' | 'Y'))
-            .collect();
-        color_letter_stream.contains(Self::TEST_IMAGE_EXPECTED_CODE)
-    }
-
-    fn is_responses_api_format(api_format: &str) -> bool {
-        matches!(
-            api_format.to_ascii_lowercase().as_str(),
-            "response" | "responses"
-        )
-    }
-
-    fn is_gemini_api_format(api_format: &str) -> bool {
-        matches!(
-            api_format.to_ascii_lowercase().as_str(),
-            "gemini" | "google"
-        )
-    }
-
-    fn normalize_base_url_for_discovery(base_url: &str) -> String {
-        base_url
-            .trim()
-            .trim_end_matches('#')
-            .trim_end_matches('/')
-            .to_string()
-    }
-
-    fn resolve_openai_models_url(&self) -> String {
-        let mut base = Self::normalize_base_url_for_discovery(&self.config.base_url);
-
-        for suffix in ["/chat/completions", "/responses", "/models"] {
-            if base.ends_with(suffix) {
-                base.truncate(base.len() - suffix.len());
-                break;
-            }
-        }
-
-        if base.is_empty() {
-            return "models".to_string();
-        }
-
-        format!("{}/models", base)
-    }
-
-    fn resolve_anthropic_models_url(&self) -> String {
-        let mut base = Self::normalize_base_url_for_discovery(&self.config.base_url);
-
-        if base.ends_with("/v1/messages") {
-            base.truncate(base.len() - "/v1/messages".len());
-            return format!("{}/v1/models", base);
-        }
-
-        if base.ends_with("/v1/models") {
-            return base;
-        }
-
-        if base.ends_with("/v1") {
-            return format!("{}/models", base);
-        }
-
-        if base.is_empty() {
-            return "v1/models".to_string();
-        }
-
-        format!("{}/v1/models", base)
-    }
-
-    fn dedupe_remote_models(models: Vec<RemoteModelInfo>) -> Vec<RemoteModelInfo> {
-        let mut seen = std::collections::HashSet::new();
-        let mut deduped = Vec::new();
-
-        for model in models {
-            if seen.insert(model.id.clone()) {
-                deduped.push(model);
-            }
-        }
-
-        deduped
-    }
-
-    async fn list_openai_models(&self) -> Result<Vec<RemoteModelInfo>> {
-        let url = self.resolve_openai_models_url();
-        let response = self
-            .apply_openai_headers(self.client.get(&url))
-            .send()
-            .await?
-            .error_for_status()?;
-
-        let payload: OpenAIModelsResponse = response.json().await?;
-        Ok(Self::dedupe_remote_models(
-            payload
-                .data
-                .into_iter()
-                .map(|model| RemoteModelInfo {
-                    id: model.id,
-                    display_name: None,
-                })
-                .collect(),
-        ))
-    }
-
-    async fn list_anthropic_models(&self) -> Result<Vec<RemoteModelInfo>> {
-        let url = self.resolve_anthropic_models_url();
-        let response = self
-            .apply_anthropic_headers(self.client.get(&url), &url)
-            .send()
-            .await?
-            .error_for_status()?;
-
-        let payload: AnthropicModelsResponse = response.json().await?;
-        Ok(Self::dedupe_remote_models(
-            payload
-                .data
-                .into_iter()
-                .map(|model| RemoteModelInfo {
-                    id: model.id,
-                    display_name: model.display_name,
-                })
-                .collect(),
-        ))
-    }
-
-    fn resolve_gemini_models_url(&self) -> String {
-        let base = Self::normalize_base_url_for_discovery(&self.config.base_url);
-        let base = Self::gemini_base_url(&base);
-        format!("{}/v1beta/models", base)
-    }
-
-    async fn list_gemini_models(&self) -> Result<Vec<RemoteModelInfo>> {
-        let url = self.resolve_gemini_models_url();
-        debug!("Gemini models list URL: {}", url);
-
-        let response = self
-            .apply_gemini_headers(self.client.get(&url))
-            .send()
-            .await?
-            .error_for_status()?;
-
-        let payload: GeminiModelsResponse = response.json().await?;
-        Ok(Self::dedupe_remote_models(
-            payload
-                .models
-                .into_iter()
-                .filter(|m| {
-                    m.supported_generation_methods.is_empty()
-                        || m.supported_generation_methods
-                            .iter()
-                            .any(|method| method == "generateContent")
-                })
-                .map(|model| {
-                    let id = model
-                        .name
-                        .strip_prefix("models/")
-                        .unwrap_or(&model.name)
-                        .to_string();
-                    RemoteModelInfo {
-                        id,
-                        display_name: model.display_name,
-                    }
-                })
-                .collect(),
-        ))
-    }
-
-    /// Create an AIClient without proxy (backward compatible)
-    pub fn new(config: AIConfig) -> Self {
-        let skip_ssl_verify = config.skip_ssl_verify;
-        let client = Self::create_http_client(None, skip_ssl_verify);
-        Self { client, config }
-    }
-
-    /// Create an AIClient with proxy configuration
-    pub fn new_with_proxy(config: AIConfig, proxy_config: Option<ProxyConfig>) -> Self {
-        let skip_ssl_verify = config.skip_ssl_verify;
-        let client = Self::create_http_client(proxy_config, skip_ssl_verify);
-        Self { client, config }
-    }
-
-    /// Create an HTTP client (supports proxy config and SSL verification control)
-    fn create_http_client(proxy_config: Option<ProxyConfig>, skip_ssl_verify: bool) -> Client {
-        let mut builder = Client::builder()
-            // SSE requests can legitimately stay open for a long time while the model
-            // thinks or executes tools. Keep only connect timeout here and let the
-            // stream handlers enforce idle timeouts between chunks.
-            .connect_timeout(std::time::Duration::from_secs(
-                Self::STREAM_CONNECT_TIMEOUT_SECS,
-            ))
-            .user_agent("BitFun/1.0")
-            .pool_idle_timeout(std::time::Duration::from_secs(
-                Self::HTTP_POOL_IDLE_TIMEOUT_SECS,
-            ))
-            .pool_max_idle_per_host(4)
-            .tcp_keepalive(Some(std::time::Duration::from_secs(
-                Self::HTTP_TCP_KEEPALIVE_SECS,
-            )))
-            .danger_accept_invalid_certs(skip_ssl_verify);
-
-        if skip_ssl_verify {
-            warn!("SSL certificate verification disabled - security risk, use only in test environments");
-        }
-
-        // rustls mode does not support http2_keep_alive_interval/http2_keep_alive_timeout.
-        if let Some(proxy_cfg) = proxy_config {
-            if proxy_cfg.enabled && !proxy_cfg.url.is_empty() {
-                match Self::build_proxy(&proxy_cfg) {
-                    Ok(proxy) => {
-                        info!("Using proxy: {}", proxy_cfg.url);
-                        builder = builder.proxy(proxy);
-                    }
-                    Err(e) => {
-                        error!(
-                            "Proxy configuration failed: {}, proceeding without proxy",
-                            e
-                        );
-                        builder = builder.no_proxy();
-                    }
-                }
-            } else {
-                builder = builder.no_proxy();
-            }
-        } else {
-            builder = builder.no_proxy();
-        }
-
-        match builder.build() {
-            Ok(client) => client,
-            Err(e) => {
-                error!(
-                    "HTTP client initialization failed: {}, using default client",
-                    e
-                );
-                Client::new()
-            }
-        }
-    }
-
-    fn build_proxy(config: &ProxyConfig) -> Result<Proxy> {
-        let mut proxy =
-            Proxy::all(&config.url).map_err(|e| anyhow!("Failed to create proxy: {}", e))?;
-
-        if let (Some(username), Some(password)) = (&config.username, &config.password) {
-            if !username.is_empty() && !password.is_empty() {
-                proxy = proxy.basic_auth(username, password);
-                debug!("Proxy authentication configured for user: {}", username);
-            }
-        }
-
-        Ok(proxy)
-    }
-
-    fn get_api_format(&self) -> &str {
-        &self.config.format
-    }
-
-    /// Whether the URL is Alibaba DashScope API.
-    /// Alibaba DashScope uses `enable_thinking`=true/false for thinking, not the `thinking` object.
-    fn is_dashscope_url(url: &str) -> bool {
-        url.contains("dashscope.aliyuncs.com")
-    }
-
-    /// Whether the URL is MiniMax API.
-    /// MiniMax (api.minimaxi.com) uses `reasoning_split=true` to enable streamed thinking content
-    /// delivered via `delta.reasoning_details` rather than the standard `reasoning_content` field.
-    fn is_minimax_url(url: &str) -> bool {
-        url.contains("api.minimaxi.com")
-    }
-
-    /// Apply thinking-related fields onto the request body (mutates `request_body`).
-    ///
-    /// * `enable` - whether thinking process is enabled
-    /// * `url` - request URL
-    /// * `model_name` - model name (e.g. for Claude budget_tokens in Anthropic format)
-    /// * `api_format` - "openai" or "anthropic"
-    /// * `max_tokens` - optional max_tokens (for Anthropic Claude budget_tokens)
-    fn apply_thinking_fields(
-        request_body: &mut serde_json::Value,
-        enable: bool,
-        url: &str,
-        model_name: &str,
-        api_format: &str,
-        max_tokens: Option<u32>,
-    ) {
-        if Self::is_dashscope_url(url) && api_format.eq_ignore_ascii_case("openai") {
-            request_body["enable_thinking"] = serde_json::json!(enable);
-            return;
-        }
-        if Self::is_minimax_url(url) && api_format.eq_ignore_ascii_case("openai") {
-            if enable {
-                request_body["reasoning_split"] = serde_json::json!(true);
-            }
-            return;
-        }
-        let thinking_value = if enable {
-            if api_format.eq_ignore_ascii_case("anthropic") && model_name.starts_with("claude") {
-                let mut obj = serde_json::map::Map::new();
-                obj.insert(
-                    "type".to_string(),
-                    serde_json::Value::String("enabled".to_string()),
-                );
-                if let Some(m) = max_tokens {
-                    obj.insert(
-                        "budget_tokens".to_string(),
-                        serde_json::json!(10000u32.min(m * 3 / 4)),
-                    );
-                }
-                serde_json::Value::Object(obj)
-            } else {
-                serde_json::json!({ "type": "enabled" })
-            }
-        } else {
-            serde_json::json!({ "type": "disabled" })
-        };
-        request_body["thinking"] = thinking_value;
-    }
-
-    /// Whether to append the `tool_stream` request field.
-    ///
-    /// Only Zhipu (https://open.bigmodel.cn) uses this field, and only for GLM models (bare version >= 4.6).
-    /// Adding this parameter for non-Zhipu APIs may cause abnormal behavior:
-    /// 1) incomplete output; (Aliyun Coding Plan, 2026-02-28)
-    /// 2) extra `` prefix on some tool names. (Aliyun Coding Plan, 2026-02-28)
-    fn should_append_tool_stream(url: &str, model_name: &str) -> bool {
-        if !url.contains("open.bigmodel.cn") {
-            return false;
-        }
-        Self::parse_glm_major_minor(model_name)
-            .map(|(major, minor)| major > 4 || (major == 4 && minor >= 6))
-            .unwrap_or(false)
-    }
-
-    /// Parse strict `glm-<major>[.<minor>]` from model names like:
-    /// - glm-4.6
-    /// - glm-5
-    ///
-    /// Models with non-numeric suffixes are treated as not requiring this GLM-specific field, e.g.:
-    /// - glm-4.6-flash
-    /// - glm-4.5v
-    fn parse_glm_major_minor(model_name: &str) -> Option<(u32, u32)> {
-        let version_part = model_name.strip_prefix("glm-")?;
-
-        if version_part.is_empty() {
-            return None;
-        }
-
-        let mut parts = version_part.split('.');
-        let major: u32 = parts.next()?.parse().ok()?;
-        let minor: u32 = match parts.next() {
-            Some(v) => v.parse().ok()?,
-            None => 0,
-        };
-
-        // Only allow one numeric segment after the decimal point.
-        if parts.next().is_some() {
-            return None;
-        }
-
-        Some((major, minor))
-    }
-
-    /// Determine whether to use merge mode
-    ///
-    /// true: apply default headers first, then custom headers (custom can override)
-    /// false: if custom headers exist, replace defaults entirely
-    /// Default is merge mode
-    fn is_merge_headers_mode(&self) -> bool {
-        // Default to merge mode; use replace mode only when explicitly set to "replace"
-        self.config.custom_headers_mode.as_deref() != Some("replace")
-    }
-
-    /// Apply custom headers to the builder
-    fn apply_custom_headers(
-        &self,
-        mut builder: reqwest::RequestBuilder,
-    ) -> reqwest::RequestBuilder {
-        if let Some(custom_headers) = &self.config.custom_headers {
-            if !custom_headers.is_empty() {
-                for (key, value) in custom_headers {
-                    builder = builder.header(key.as_str(), value.as_str());
-                }
-            }
-        }
-        builder
-    }
-
-    /// Apply OpenAI-style request headers (merge/replace).
-    fn apply_openai_headers(
-        &self,
-        mut builder: reqwest::RequestBuilder,
-    ) -> reqwest::RequestBuilder {
-        let has_custom_headers = self
-            .config
-            .custom_headers
-            .as_ref()
-            .is_some_and(|h| !h.is_empty());
-        let is_merge_mode = self.is_merge_headers_mode();
-
-        if has_custom_headers && !is_merge_mode {
-            return self.apply_custom_headers(builder);
-        }
-
-        builder = builder
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", self.config.api_key));
-
-        if self.config.base_url.contains("openbitfun.com") {
-            builder = builder.header("X-Verification-Code", "from_bitfun");
-        }
-
-        if has_custom_headers && is_merge_mode {
-            builder = self.apply_custom_headers(builder);
-        }
-
-        builder
-    }
-
-    /// Apply Anthropic-style request headers (merge/replace).
-    fn apply_anthropic_headers(
-        &self,
-        mut builder: reqwest::RequestBuilder,
-        url: &str,
-    ) -> reqwest::RequestBuilder {
-        let has_custom_headers = self
-            .config
-            .custom_headers
-            .as_ref()
-            .is_some_and(|h| !h.is_empty());
-        let is_merge_mode = self.is_merge_headers_mode();
-
-        if has_custom_headers && !is_merge_mode {
-            return self.apply_custom_headers(builder);
-        }
-
-        builder = builder.header("Content-Type", "application/json");
-
-        if url.contains("bigmodel.cn") {
-            builder = builder.header("Authorization", format!("Bearer {}", self.config.api_key));
-        } else {
-            builder = builder
-                .header("x-api-key", &self.config.api_key)
-                .header("anthropic-version", "2023-06-01");
-        }
-
-        if url.contains("openbitfun.com") {
-            builder = builder.header("X-Verification-Code", "from_bitfun");
-        }
-
-        if has_custom_headers && is_merge_mode {
-            builder = self.apply_custom_headers(builder);
-        }
-
-        builder
-    }
-
-    /// Apply Gemini-style request headers (merge/replace).
-    fn apply_gemini_headers(
-        &self,
-        mut builder: reqwest::RequestBuilder,
-    ) -> reqwest::RequestBuilder {
-        let has_custom_headers = self
-            .config
-            .custom_headers
-            .as_ref()
-            .is_some_and(|h| !h.is_empty());
-        let is_merge_mode = self.is_merge_headers_mode();
-
-        if has_custom_headers && !is_merge_mode {
-            return self.apply_custom_headers(builder);
-        }
-
-        builder = builder
-            .header("Content-Type", "application/json")
-            .header("x-goog-api-key", &self.config.api_key)
-            .header("Authorization", format!("Bearer {}", self.config.api_key));
-
-        if self.config.base_url.contains("openbitfun.com") {
-            builder = builder.header("X-Verification-Code", "from_bitfun");
-        }
-
-        if has_custom_headers && is_merge_mode {
-            builder = self.apply_custom_headers(builder);
-        }
-
-        builder
-    }
-
-    fn merge_json_value(target: &mut serde_json::Value, overlay: serde_json::Value) {
-        match (target, overlay) {
-            (serde_json::Value::Object(target_map), serde_json::Value::Object(overlay_map)) => {
-                for (key, value) in overlay_map {
-                    let entry = target_map.entry(key).or_insert(serde_json::Value::Null);
-                    Self::merge_json_value(entry, value);
-                }
-            }
-            (target_slot, overlay_value) => {
-                *target_slot = overlay_value;
-            }
-        }
-    }
-
-    fn ensure_gemini_generation_config(
-        request_body: &mut serde_json::Value,
-    ) -> &mut serde_json::Map<String, serde_json::Value> {
-        if !request_body
-            .get("generationConfig")
-            .is_some_and(serde_json::Value::is_object)
-        {
-            request_body["generationConfig"] = serde_json::json!({});
-        }
-
-        request_body["generationConfig"]
-            .as_object_mut()
-            .expect("generationConfig must be an object")
-    }
-
-    fn insert_gemini_generation_field(
-        request_body: &mut serde_json::Value,
-        key: &str,
-        value: serde_json::Value,
-    ) {
-        Self::ensure_gemini_generation_config(request_body).insert(key.to_string(), value);
-    }
-
-    fn normalize_gemini_stop_sequences(value: &serde_json::Value) -> Option<serde_json::Value> {
-        match value {
-            serde_json::Value::String(sequence) => {
-                Some(serde_json::Value::Array(vec![serde_json::Value::String(
-                    sequence.clone(),
-                )]))
-            }
-            serde_json::Value::Array(items) => {
-                let sequences = items
-                    .iter()
-                    .filter_map(|item| item.as_str().map(|sequence| sequence.to_string()))
-                    .map(serde_json::Value::String)
-                    .collect::<Vec<_>>();
-
-                if sequences.is_empty() {
-                    None
-                } else {
-                    Some(serde_json::Value::Array(sequences))
-                }
-            }
-            _ => None,
-        }
-    }
-
-    fn apply_gemini_response_format_translation(
-        request_body: &mut serde_json::Value,
-        response_format: &serde_json::Value,
-    ) -> bool {
-        match response_format {
-            serde_json::Value::String(kind) if matches!(kind.as_str(), "json" | "json_object") => {
-                Self::insert_gemini_generation_field(
-                    request_body,
-                    "responseMimeType",
-                    serde_json::Value::String("application/json".to_string()),
-                );
-                true
-            }
-            serde_json::Value::Object(map) => {
-                let Some(kind) = map.get("type").and_then(serde_json::Value::as_str) else {
-                    return false;
-                };
-
-                match kind {
-                    "json" | "json_object" => {
-                        Self::insert_gemini_generation_field(
-                            request_body,
-                            "responseMimeType",
-                            serde_json::Value::String("application/json".to_string()),
-                        );
-                        true
-                    }
-                    "json_schema" => {
-                        Self::insert_gemini_generation_field(
-                            request_body,
-                            "responseMimeType",
-                            serde_json::Value::String("application/json".to_string()),
-                        );
-
-                        if let Some(schema) = map
-                            .get("json_schema")
-                            .and_then(serde_json::Value::as_object)
-                            .and_then(|json_schema| json_schema.get("schema"))
-                            .or_else(|| map.get("schema"))
-                        {
-                            Self::insert_gemini_generation_field(
-                                request_body,
-                                "responseJsonSchema",
-                                GeminiMessageConverter::sanitize_schema(schema.clone()),
-                            );
-                        }
-
-                        true
-                    }
-                    _ => false,
-                }
-            }
-            _ => false,
-        }
-    }
-
-    fn translate_gemini_extra_body(
-        request_body: &mut serde_json::Value,
-        extra_obj: &mut serde_json::Map<String, serde_json::Value>,
-    ) {
-        if let Some(max_tokens) = extra_obj.remove("max_tokens") {
-            Self::insert_gemini_generation_field(request_body, "maxOutputTokens", max_tokens);
-        }
-
-        if let Some(temperature) = extra_obj.remove("temperature") {
-            Self::insert_gemini_generation_field(request_body, "temperature", temperature);
-        }
-
-        let top_p = extra_obj
-            .remove("top_p")
-            .or_else(|| extra_obj.remove("topP"));
-        if let Some(top_p) = top_p {
-            Self::insert_gemini_generation_field(request_body, "topP", top_p);
-        }
-
-        if let Some(stop_sequences) = extra_obj
-            .get("stop")
-            .and_then(Self::normalize_gemini_stop_sequences)
-        {
-            extra_obj.remove("stop");
-            Self::insert_gemini_generation_field(request_body, "stopSequences", stop_sequences);
-        }
-
-        if let Some(response_mime_type) = extra_obj
-            .remove("responseMimeType")
-            .or_else(|| extra_obj.remove("response_mime_type"))
-        {
-            Self::insert_gemini_generation_field(
-                request_body,
-                "responseMimeType",
-                response_mime_type,
-            );
-        }
-
-        if let Some(response_schema) = extra_obj
-            .remove("responseJsonSchema")
-            .or_else(|| extra_obj.remove("responseSchema"))
-            .or_else(|| extra_obj.remove("response_schema"))
-        {
-            Self::insert_gemini_generation_field(
-                request_body,
-                "responseJsonSchema",
-                GeminiMessageConverter::sanitize_schema(response_schema),
-            );
-        }
-
-        if let Some(response_format) = extra_obj.get("response_format").cloned() {
-            if Self::apply_gemini_response_format_translation(request_body, &response_format) {
-                extra_obj.remove("response_format");
-            }
-        }
-    }
-
-    fn unified_usage_to_gemini_usage(usage: ai_stream_handlers::UnifiedTokenUsage) -> GeminiUsage {
-        GeminiUsage {
-            prompt_token_count: usage.prompt_token_count,
-            candidates_token_count: usage.candidates_token_count,
-            total_token_count: usage.total_token_count,
-            reasoning_token_count: usage.reasoning_token_count,
-            cached_content_token_count: usage.cached_content_token_count,
-        }
-    }
-
-    /// Build an OpenAI-format request body
-    fn build_openai_request_body(
-        &self,
-        url: &str,
-        openai_messages: Vec<serde_json::Value>,
-        openai_tools: Option<Vec<serde_json::Value>>,
-        extra_body: Option<serde_json::Value>,
-    ) -> serde_json::Value {
-        let mut request_body = serde_json::json!({
-            "model": self.config.model,
-            "messages": openai_messages,
-            "stream": true
-        });
-
-        let model_name = self.config.model.to_lowercase();
-
-        if Self::should_append_tool_stream(url, &model_name) {
-            request_body["tool_stream"] = serde_json::Value::Bool(true);
-        }
-
-        Self::apply_thinking_fields(
-            &mut request_body,
-            self.config.enable_thinking_process,
-            url,
-            &model_name,
-            "openai",
-            self.config.max_tokens,
-        );
-
-        if let Some(max_tokens) = self.config.max_tokens {
-            request_body["max_tokens"] = serde_json::json!(max_tokens);
-        }
-
-        if let Some(extra) = extra_body {
-            if let Some(extra_obj) = extra.as_object() {
-                for (key, value) in extra_obj {
-                    request_body[key] = value.clone();
-                }
-                debug!(target: "ai::openai_stream_request", "Applied extra_body overrides: {:?}", extra_obj.keys().collect::<Vec<_>>());
-            }
-        }
-
-        // This client currently consumes only the first choice in stream handling.
-        // Remove custom n override and keep provider defaults.
-        if let Some(request_obj) = request_body.as_object_mut() {
-            if let Some(existing_n) = request_obj.remove("n") {
-                warn!(
-                    target: "ai::openai_stream_request",
-                    "Removed custom request field n={} because the stream processor only handles the first choice",
-                    existing_n
-                );
-            }
-        }
-
-        debug!(target: "ai::openai_stream_request",
-            "OpenAI stream request body (excluding tools):\n{}",
-            serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "serialization failed".to_string())
-        );
-
-        if let Some(tools) = openai_tools {
-            let tool_names = tools
-                .iter()
-                .map(Self::extract_openai_tool_name)
-                .collect::<Vec<_>>();
-            debug!(target: "ai::openai_stream_request", "\ntools: {:?}", tool_names);
-            if !tools.is_empty() {
-                request_body["tools"] = serde_json::Value::Array(tools);
-                // Respect `extra_body` overrides (e.g. tool_choice="required") when present.
-                let has_tool_choice = request_body
-                    .get("tool_choice")
-                    .is_some_and(|v| !v.is_null());
-                if !has_tool_choice {
-                    request_body["tool_choice"] = serde_json::Value::String("auto".to_string());
-                }
-            }
-        }
-
-        request_body
-    }
-
-    /// Build a Responses API request body.
-    fn build_responses_request_body(
-        &self,
-        instructions: Option<String>,
-        response_input: Vec<serde_json::Value>,
-        openai_tools: Option<Vec<serde_json::Value>>,
-        extra_body: Option<serde_json::Value>,
-    ) -> serde_json::Value {
-        let mut request_body = serde_json::json!({
-            "model": self.config.model,
-            "input": response_input,
-            "stream": true
-        });
-
-        if let Some(instructions) = instructions.filter(|value| !value.trim().is_empty()) {
-            request_body["instructions"] = serde_json::Value::String(instructions);
-        }
-
-        if let Some(max_tokens) = self.config.max_tokens {
-            request_body["max_output_tokens"] = serde_json::json!(max_tokens);
-        }
-
-        if let Some(ref effort) = self.config.reasoning_effort {
-            request_body["reasoning"] = serde_json::json!({
-                "effort": effort,
-                "summary": "auto"
-            });
-        }
-
-        if let Some(extra) = extra_body {
-            if let Some(extra_obj) = extra.as_object() {
-                for (key, value) in extra_obj {
-                    request_body[key] = value.clone();
-                }
-                debug!(
-                    target: "ai::responses_stream_request",
-                    "Applied extra_body overrides: {:?}",
-                    extra_obj.keys().collect::<Vec<_>>()
-                );
-            }
-        }
-
-        debug!(
-            target: "ai::responses_stream_request",
-            "Responses stream request body (excluding tools):\n{}",
-            serde_json::to_string_pretty(&request_body)
-                .unwrap_or_else(|_| "serialization failed".to_string())
-        );
-
-        if let Some(tools) = openai_tools {
-            let tool_names = tools
-                .iter()
-                .map(Self::extract_openai_tool_name)
-                .collect::<Vec<_>>();
-            debug!(target: "ai::responses_stream_request", "\ntools: {:?}", tool_names);
-            if !tools.is_empty() {
-                request_body["tools"] = serde_json::Value::Array(tools);
-                // Respect `extra_body` overrides (e.g. tool_choice="required") when present.
-                let has_tool_choice = request_body
-                    .get("tool_choice")
-                    .is_some_and(|v| !v.is_null());
-                if !has_tool_choice {
-                    request_body["tool_choice"] = serde_json::Value::String("auto".to_string());
-                }
-            }
-        }
-
-        request_body
-    }
-
-    /// Build an Anthropic-format request body
-    fn build_anthropic_request_body(
-        &self,
-        url: &str,
-        system_message: Option<String>,
-        anthropic_messages: Vec<serde_json::Value>,
-        anthropic_tools: Option<Vec<serde_json::Value>>,
-        extra_body: Option<serde_json::Value>,
-    ) -> serde_json::Value {
-        let max_tokens = self.config.max_tokens.unwrap_or(8192);
-
-        let mut request_body = serde_json::json!({
-            "model": self.config.model,
-            "messages": anthropic_messages,
-            "max_tokens": max_tokens,
-            "stream": true
-        });
-
-        let model_name = self.config.model.to_lowercase();
-
-        // Zhipu extension: only set `tool_stream` for open.bigmodel.cn.
-        if Self::should_append_tool_stream(url, &model_name) {
-            request_body["tool_stream"] = serde_json::Value::Bool(true);
-        }
-
-        Self::apply_thinking_fields(
-            &mut request_body,
-            self.config.enable_thinking_process,
-            url,
-            &model_name,
-            "anthropic",
-            Some(max_tokens),
-        );
-
-        if let Some(system) = system_message {
-            request_body["system"] = serde_json::Value::String(system);
-        }
-
-        if let Some(extra) = extra_body {
-            if let Some(extra_obj) = extra.as_object() {
-                for (key, value) in extra_obj {
-                    request_body[key] = value.clone();
-                }
-                debug!(target: "ai::anthropic_stream_request", "Applied extra_body overrides: {:?}", extra_obj.keys().collect::<Vec<_>>());
-            }
-        }
-
-        debug!(target: "ai::anthropic_stream_request",
-            "Anthropic stream request body (excluding tools):\n{}",
-            serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "serialization failed".to_string())
-        );
-
-        if let Some(tools) = anthropic_tools {
-            let tool_names = tools
-                .iter()
-                .map(Self::extract_anthropic_tool_name)
-                .collect::<Vec<_>>();
-            debug!(target: "ai::anthropic_stream_request", "\ntools: {:?}", tool_names);
-            if !tools.is_empty() {
-                request_body["tools"] = serde_json::Value::Array(tools);
-            }
-        }
-
-        request_body
-    }
-
-    /// Build a Gemini-format request body.
-    fn build_gemini_request_body(
-        &self,
-        system_instruction: Option<serde_json::Value>,
-        contents: Vec<serde_json::Value>,
-        gemini_tools: Option<Vec<serde_json::Value>>,
-        extra_body: Option<serde_json::Value>,
-    ) -> serde_json::Value {
-        let mut request_body = serde_json::json!({
-            "contents": contents,
-        });
-
-        if let Some(system_instruction) = system_instruction {
-            request_body["systemInstruction"] = system_instruction;
-        }
-
-        if let Some(max_tokens) = self.config.max_tokens {
-            Self::insert_gemini_generation_field(
-                &mut request_body,
-                "maxOutputTokens",
-                serde_json::json!(max_tokens),
-            );
-        }
-
-        if let Some(temperature) = self.config.temperature {
-            Self::insert_gemini_generation_field(
-                &mut request_body,
-                "temperature",
-                serde_json::json!(temperature),
-            );
-        }
-
-        if let Some(top_p) = self.config.top_p {
-            Self::insert_gemini_generation_field(
-                &mut request_body,
-                "topP",
-                serde_json::json!(top_p),
-            );
-        }
-
-        if self.config.enable_thinking_process {
-            Self::insert_gemini_generation_field(
-                &mut request_body,
-                "thinkingConfig",
-                serde_json::json!({
-                    "includeThoughts": true,
-                }),
-            );
-        }
-
-        if let Some(tools) = gemini_tools {
-            let tool_names = tools
-                .iter()
-                .flat_map(|tool| {
-                    if let Some(declarations) = tool
-                        .get("functionDeclarations")
-                        .and_then(|value| value.as_array())
-                    {
-                        declarations
-                            .iter()
-                            .filter_map(|declaration| {
-                                declaration
-                                    .get("name")
-                                    .and_then(|value| value.as_str())
-                                    .map(str::to_string)
-                            })
-                            .collect::<Vec<_>>()
-                    } else {
-                        tool.as_object()
-                            .into_iter()
-                            .flat_map(|map| map.keys().cloned())
-                            .collect::<Vec<_>>()
-                    }
-                })
-                .collect::<Vec<_>>();
-            debug!(target: "ai::gemini_stream_request", "\ntools: {:?}", tool_names);
-
-            if !tools.is_empty() {
-                request_body["tools"] = serde_json::Value::Array(tools);
-                let has_function_declarations = request_body["tools"]
-                    .as_array()
-                    .map(|tools| {
-                        tools
-                            .iter()
-                            .any(|tool| tool.get("functionDeclarations").is_some())
-                    })
-                    .unwrap_or(false);
-
-                if has_function_declarations {
-                    request_body["toolConfig"] = serde_json::json!({
-                        "functionCallingConfig": {
-                            "mode": "AUTO"
-                        }
-                    });
-                }
-            }
-        }
-
-        if let Some(extra) = extra_body {
-            if let Some(mut extra_obj) = extra.as_object().cloned() {
-                Self::translate_gemini_extra_body(&mut request_body, &mut extra_obj);
-                let override_keys = extra_obj.keys().cloned().collect::<Vec<_>>();
-
-                for (key, value) in extra_obj {
-                    if let Some(request_obj) = request_body.as_object_mut() {
-                        let target = request_obj.entry(key).or_insert(serde_json::Value::Null);
-                        Self::merge_json_value(target, value);
-                    }
-                }
-                debug!(
-                    target: "ai::gemini_stream_request",
-                    "Applied extra_body overrides: {:?}",
-                    override_keys
-                );
-            }
-        }
-
-        debug!(
-            target: "ai::gemini_stream_request",
-            "Gemini stream request body:\n{}",
-            serde_json::to_string_pretty(&request_body)
-                .unwrap_or_else(|_| "serialization failed".to_string())
-        );
-
-        request_body
-    }
-
-    fn resolve_gemini_request_url(base_url: &str, model_name: &str) -> String {
-        let trimmed = base_url.trim().trim_end_matches('/');
-        if trimmed.is_empty() {
-            return String::new();
-        }
-
-        let base = Self::gemini_base_url(trimmed);
-        let encoded_model = urlencoding::encode(model_name.trim());
-        format!(
-            "{}/v1beta/models/{}:streamGenerateContent?alt=sse",
-            base, encoded_model
-        )
-    }
-
-    /// Strip /v1beta, /models/... and similar suffixes from a gemini URL,
-    /// returning only the bare host root (e.g. https://generativelanguage.googleapis.com).
-    fn gemini_base_url(url: &str) -> &str {
-        let mut u = url;
-        if let Some(pos) = u.find("/v1beta") {
-            u = &u[..pos];
-        }
-        if let Some(pos) = u.find("/models/") {
-            u = &u[..pos];
-        }
-        u.trim_end_matches('/')
-    }
-
-    fn extract_openai_tool_name(tool: &serde_json::Value) -> String {
-        tool.get("function")
-            .and_then(|f| f.get("name"))
-            .and_then(|n| n.as_str())
-            .unwrap_or("unknown")
-            .to_string()
-    }
-
-    fn extract_anthropic_tool_name(tool: &serde_json::Value) -> String {
-        tool.get("name")
-            .and_then(|n| n.as_str())
-            .unwrap_or("unknown")
-            .to_string()
-    }
-
-    /// Send a streaming message request
-    ///
-    /// Returns `StreamResponse` with:
-    /// - `stream`: parsed response stream
-    /// - `raw_sse_rx`: raw SSE receiver (for collecting data during error diagnostics)
-    pub async fn send_message_stream(
-        &self,
-        messages: Vec<Message>,
-        tools: Option<Vec<ToolDefinition>>,
-    ) -> Result<StreamResponse> {
-        let custom_body = self.config.custom_request_body.clone();
-        self.send_message_stream_with_extra_body(messages, tools, custom_body)
-            .await
-    }
-
-    /// Send a streaming message request with extra request body overrides
-    ///
-    /// Returns `StreamResponse` with:
-    /// - `stream`: parsed response stream
-    /// - `raw_sse_rx`: raw SSE receiver (for collecting data during error diagnostics)
-    pub async fn send_message_stream_with_extra_body(
-        &self,
-        messages: Vec<Message>,
-        tools: Option<Vec<ToolDefinition>>,
-        extra_body: Option<serde_json::Value>,
-    ) -> Result<StreamResponse> {
-        let max_tries = 3;
-        match self.get_api_format().to_lowercase().as_str() {
-            "openai" => {
-                self.send_openai_stream(messages, tools, extra_body, max_tries)
-                    .await
-            }
-            format if Self::is_gemini_api_format(format) => {
-                self.send_gemini_stream(messages, tools, extra_body, max_tries)
-                    .await
-            }
-            format if Self::is_responses_api_format(format) => {
-                self.send_responses_stream(messages, tools, extra_body, max_tries)
-                    .await
-            }
-            "anthropic" => {
-                self.send_anthropic_stream(messages, tools, extra_body, max_tries)
-                    .await
-            }
-            _ => Err(anyhow!("Unknown API format: {}", self.get_api_format())),
-        }
-    }
-
-    /// Send an OpenAI streaming request with retries
-    ///
-    /// # Parameters
-    /// - `messages`: message list
-    /// - `tools`: tool definitions
-    /// - `extra_body`: extra request body parameters
-    /// - `max_tries`: max attempts (including the first)
-    async fn send_openai_stream(
-        &self,
-        messages: Vec<Message>,
-        tools: Option<Vec<ToolDefinition>>,
-        extra_body: Option<serde_json::Value>,
-        max_tries: usize,
-    ) -> Result<StreamResponse> {
-        let url = self.config.request_url.clone();
-        debug!(
-            "OpenAI config: model={}, request_url={}, max_tries={}",
-            self.config.model, self.config.request_url, max_tries
-        );
-
-        // Use OpenAI message converter
-        let openai_messages = OpenAIMessageConverter::convert_messages(messages);
-        let openai_tools = OpenAIMessageConverter::convert_tools(tools);
-
-        // Build request body
-        let request_body =
-            self.build_openai_request_body(&url, openai_messages, openai_tools, extra_body);
-
-        let mut last_error = None;
-        let base_wait_time_ms = 500;
-
-        for attempt in 0..max_tries {
-            let request_start_time = std::time::Instant::now();
-
-            // Send request - apply request headers
-            let request_builder = self.apply_openai_headers(self.client.post(&url));
-            let response_result = request_builder.json(&request_body).send().await;
-
-            let response = match response_result {
-                Ok(resp) => {
-                    let connect_time = request_start_time.elapsed().as_millis();
-                    let status = resp.status();
-
-                    if status.is_client_error() {
-                        let error_text = resp
-                            .text()
-                            .await
-                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
-                        error!(
"OpenAI Streaming API client error {}: {}", - status, error_text - ); - return Err(anyhow!( - "OpenAI Streaming API client error {}: {}", - status, - error_text - )); - } - - if status.is_success() { - debug!( - "Stream request connected: {}ms, status: {}, attempt: {}/{}", - connect_time, - status, - attempt + 1, - max_tries - ); - resp - } else { - let error_text = resp - .text() - .await - .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); - let error = - anyhow!("OpenAI Streaming API error {}: {}", status, error_text); - warn!( - "Stream request failed (attempt {}/{}): {}", - attempt + 1, - max_tries, - error - ); - last_error = Some(error); - - if attempt < max_tries - 1 { - let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); - debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - } - continue; - } - } - Err(e) => { - let connect_time = request_start_time.elapsed().as_millis(); - let error = anyhow!("Stream request connection failed: {}", e); - warn!( - "Stream request connection failed: {}ms, attempt {}/{}, error: {}", - connect_time, - attempt + 1, - max_tries, - e - ); - last_error = Some(error); - - if attempt < max_tries - 1 { - let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); - debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - } - continue; - } - }; - - // Success: create channels and return - let (tx, rx) = mpsc::unbounded_channel(); - let (tx_raw, rx_raw) = mpsc::unbounded_channel(); - - tokio::spawn(handle_openai_stream( - response, - tx, - Some(tx_raw), - self.config.inline_think_in_text, - )); - - return Ok(StreamResponse { - stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), - raw_sse_rx: Some(rx_raw), - }); - } - - let error_msg = format!( - "Stream request failed after {} attempts: {}", - max_tries, - last_error.unwrap_or_else(|| anyhow!("Unknown error")) - ); - error!("{}", error_msg); - Err(anyhow!(error_msg)) - } - - /// Send a Gemini streaming request with retries. 
-    async fn send_gemini_stream(
-        &self,
-        messages: Vec<Message>,
-        tools: Option<Vec<ToolDefinition>>,
-        extra_body: Option<serde_json::Value>,
-        max_tries: usize,
-    ) -> Result<StreamResponse> {
-        let url = Self::resolve_gemini_request_url(&self.config.request_url, &self.config.model);
-        debug!(
-            "Gemini config: model={}, request_url={}, max_tries={}",
-            self.config.model, url, max_tries
-        );
-
-        let (system_instruction, contents) =
-            GeminiMessageConverter::convert_messages(messages, &self.config.model);
-        let gemini_tools = GeminiMessageConverter::convert_tools(tools);
-        let request_body =
-            self.build_gemini_request_body(system_instruction, contents, gemini_tools, extra_body);
-
-        let mut last_error = None;
-        let base_wait_time_ms = 500;
-
-        for attempt in 0..max_tries {
-            let request_start_time = std::time::Instant::now();
-            let request_builder = self.apply_gemini_headers(self.client.post(&url));
-            let response_result = request_builder.json(&request_body).send().await;
-
-            let response = match response_result {
-                Ok(resp) => {
-                    let connect_time = request_start_time.elapsed().as_millis();
-                    let status = resp.status();
-
-                    if status.is_client_error() {
-                        let error_text = resp
-                            .text()
-                            .await
-                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
-                        error!(
-                            "Gemini Streaming API client error {}: {}",
-                            status, error_text
-                        );
-                        return Err(anyhow!(
-                            "Gemini Streaming API client error {}: {}",
-                            status,
-                            error_text
-                        ));
-                    }
-
-                    if status.is_success() {
-                        debug!(
-                            "Gemini stream request connected: {}ms, status: {}, attempt: {}/{}",
-                            connect_time,
-                            status,
-                            attempt + 1,
-                            max_tries
-                        );
-                        resp
-                    } else {
-                        let error_text = resp
-                            .text()
-                            .await
-                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
-                        let error =
-                            anyhow!("Gemini Streaming API error {}: {}", status, error_text);
-                        warn!(
-                            "Gemini stream request failed: {}ms, attempt {}/{}, error: {}",
-                            connect_time,
-                            attempt + 1,
-                            max_tries,
-                            error
-                        );
-                        last_error = Some(error);
-
-                        if attempt < max_tries - 1 {
-                            let delay_ms = base_wait_time_ms * (1 << attempt.min(3));
-                            debug!(
-                                "Retrying Gemini after {}ms (attempt {})",
-                                delay_ms,
-                                attempt + 2
-                            );
-                            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
-                        }
-                        continue;
-                    }
-                }
-                Err(e) => {
-                    let connect_time = request_start_time.elapsed().as_millis();
-                    let error = anyhow!("Gemini stream request connection failed: {}", e);
-                    warn!(
-                        "Gemini stream request connection failed: {}ms, attempt {}/{}, error: {}",
-                        connect_time,
-                        attempt + 1,
-                        max_tries,
-                        e
-                    );
-                    last_error = Some(error);
-
-                    if attempt < max_tries - 1 {
-                        let delay_ms = base_wait_time_ms * (1 << attempt.min(3));
-                        debug!(
-                            "Retrying Gemini after {}ms (attempt {})",
-                            delay_ms,
-                            attempt + 2
-                        );
-                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
-                    }
-                    continue;
-                }
-            };
-
-            let (tx, rx) = mpsc::unbounded_channel();
-            let (tx_raw, rx_raw) = mpsc::unbounded_channel();
-
-            tokio::spawn(handle_gemini_stream(response, tx, Some(tx_raw)));
-
-            return Ok(StreamResponse {
-                stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)),
-                raw_sse_rx: Some(rx_raw),
-            });
-        }
-
-        let error_msg = format!(
-            "Gemini stream request failed after {} attempts: {}",
-            max_tries,
-            last_error.unwrap_or_else(|| anyhow!("Unknown error"))
-        );
-        error!("{}", error_msg);
-        Err(anyhow!(error_msg))
-    }
-
-    /// Send a Responses API streaming request with retries.
-    async fn send_responses_stream(
-        &self,
-        messages: Vec<Message>,
-        tools: Option<Vec<ToolDefinition>>,
-        extra_body: Option<serde_json::Value>,
-        max_tries: usize,
-    ) -> Result<StreamResponse> {
-        let url = self.config.request_url.clone();
-        debug!(
-            "Responses config: model={}, request_url={}, max_tries={}",
-            self.config.model, self.config.request_url, max_tries
-        );
-
-        let (instructions, response_input) =
-            OpenAIMessageConverter::convert_messages_to_responses_input(messages);
-        let openai_tools = OpenAIMessageConverter::convert_tools(tools);
-        let request_body = self.build_responses_request_body(
-            instructions,
-            response_input,
-            openai_tools,
-            extra_body,
-        );
-
-        let mut last_error = None;
-        let base_wait_time_ms = 500;
-
-        for attempt in 0..max_tries {
-            let request_start_time = std::time::Instant::now();
-            let request_builder = self.apply_openai_headers(self.client.post(&url));
-            let response_result = request_builder.json(&request_body).send().await;
-
-            let response = match response_result {
-                Ok(resp) => {
-                    let connect_time = request_start_time.elapsed().as_millis();
-                    let status = resp.status();
-
-                    if status.is_client_error() {
-                        let error_text = resp
-                            .text()
-                            .await
-                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
-                        error!("Responses API client error {}: {}", status, error_text);
-                        return Err(anyhow!(
-                            "Responses API client error {}: {}",
-                            status,
-                            error_text
-                        ));
-                    }
-
-                    if status.is_success() {
-                        debug!(
-                            "Responses request connected: {}ms, status: {}, attempt: {}/{}",
-                            connect_time,
-                            status,
-                            attempt + 1,
-                            max_tries
-                        );
-                        resp
-                    } else {
-                        let error_text = resp
-                            .text()
-                            .await
-                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
-                        let error = anyhow!("Responses API error {}: {}", status, error_text);
-                        warn!(
-                            "Responses request failed (attempt {}/{}): {}",
-                            attempt + 1,
-                            max_tries,
-                            error
-                        );
-                        last_error = Some(error);
-
-                        if attempt < max_tries - 1 {
-                            let delay_ms = base_wait_time_ms * (1 << attempt.min(3));
-                            debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2);
-                            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
-                        }
-                        continue;
-                    }
-                }
-                Err(e) => {
-                    let connect_time = request_start_time.elapsed().as_millis();
-                    let error = anyhow!("Responses request connection failed: {}", e);
-                    warn!(
-                        "Responses request connection failed: {}ms, attempt {}/{}, error: {}",
-                        connect_time,
-                        attempt + 1,
-                        max_tries,
-                        e
-                    );
-                    last_error = Some(error);
-
-                    if attempt < max_tries - 1 {
-                        let delay_ms = base_wait_time_ms * (1 << attempt.min(3));
-                        debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2);
-                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
-                    }
-                    continue;
-                }
-            };
+pub struct AIClient {
+    pub(crate) client: Client,
+    pub config: AIConfig,
+}
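+// The slimmed-down client keeps only the HTTP transport and its AIConfig;
+// request building, streaming, discovery, and health checks now live in the
+// provider and helper modules wired up below.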
"iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAACBklEQVR42u3ZsREAIAwDMYf9dw4txwJupI7Wua+YZEPBfO91h4ZjAgQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABIAAQAAgABAACAAEAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAAAAAAAEDRZI3QGf7jDvEPAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAACAABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAAAjABAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQALwuLkoG8OSfau4AAAAASUVORK5CYII="; + pub(crate) const STREAM_CONNECT_TIMEOUT_SECS: u64 = 10; + pub(crate) const HTTP_POOL_IDLE_TIMEOUT_SECS: u64 = 30; + pub(crate) const HTTP_TCP_KEEPALIVE_SECS: u64 = 60; - tokio::spawn(handle_responses_stream(response, tx, Some(tx_raw))); + /// Create an AIClient without proxy. + pub fn new(config: AIConfig) -> Self { + let client = http::create_http_client(None, config.skip_ssl_verify); + Self { client, config } + } - return Ok(StreamResponse { - stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), - raw_sse_rx: Some(rx_raw), - }); - } + /// Create an AIClient with proxy configuration. + pub fn new_with_proxy(config: AIConfig, proxy_config: Option) -> Self { + let client = http::create_http_client(proxy_config, config.skip_ssl_verify); + Self { client, config } + } - let error_msg = format!( - "Responses request failed after {} attempts: {}", - max_tries, - last_error.unwrap_or_else(|| anyhow!("Unknown error")) - ); - error!("{}", error_msg); - Err(anyhow!(error_msg)) + pub async fn send_message_stream( + &self, + messages: Vec, + tools: Option>, + ) -> Result { + let custom_body = self.config.custom_request_body.clone(); + self.send_message_stream_with_extra_body(messages, tools, custom_body) + .await } - /// Send an Anthropic streaming request with retries - /// - /// # Parameters - /// - `messages`: message list - /// - `tools`: tool definitions - /// - `extra_body`: extra request body parameters - /// - `max_tries`: max attempts (including the first) - async fn send_anthropic_stream( + pub async fn send_message_stream_with_extra_body( &self, messages: Vec, tools: Option>, extra_body: Option, - max_tries: usize, ) -> Result { - let url = self.config.request_url.clone(); - debug!( - "Anthropic config: model={}, request_url={}, max_tries={}", - self.config.model, self.config.request_url, max_tries - ); - - // Use Anthropic message converter - let (system_message, anthropic_messages) = - AnthropicMessageConverter::convert_messages(messages); - let anthropic_tools = AnthropicMessageConverter::convert_tools(tools); - - // Build request body - let request_body = self.build_anthropic_request_body( - &url, - system_message, - anthropic_messages, - anthropic_tools, - extra_body, - ); - - let mut last_error = None; - let base_wait_time_ms = 500; - - for attempt in 0..max_tries { - let request_start_time = std::time::Instant::now(); - - // Send request - apply Anthropic-style request headers - let request_builder = self.apply_anthropic_headers(self.client.post(&url), &url); - let response_result = request_builder.json(&request_body).send().await; - - let response = match response_result { - Ok(resp) => { - let connect_time = request_start_time.elapsed().as_millis(); - let status = resp.status(); - - if 
status.is_client_error() { - let error_text = resp - .text() - .await - .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); - error!( - "Anthropic Streaming API client error {}: {}", - status, error_text - ); - return Err(anyhow!( - "Anthropic Streaming API client error {}: {}", - status, - error_text - )); - } - - if status.is_success() { - debug!( - "Stream request connected: {}ms, status: {}, attempt: {}/{}", - connect_time, - status, - attempt + 1, - max_tries - ); - resp - } else { - let error_text = resp - .text() - .await - .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); - let error = - anyhow!("Anthropic Streaming API error {}: {}", status, error_text); - warn!( - "Stream request failed (attempt {}/{}): {}", - attempt + 1, - max_tries, - error - ); - last_error = Some(error); - - if attempt < max_tries - 1 { - let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); - debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - } - continue; - } - } - Err(e) => { - let connect_time = request_start_time.elapsed().as_millis(); - let error = anyhow!("Stream request connection failed: {}", e); - warn!( - "Stream request connection failed: {}ms, attempt {}/{}, error: {}", - connect_time, - attempt + 1, - max_tries, - e - ); - last_error = Some(error); - - if attempt < max_tries - 1 { - let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); - debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - } - continue; - } - }; - - // Success: create channels and return - let (tx, rx) = mpsc::unbounded_channel(); - let (tx_raw, rx_raw) = mpsc::unbounded_channel(); - - tokio::spawn(handle_anthropic_stream(response, tx, Some(tx_raw))); - - return Ok(StreamResponse { - stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), - raw_sse_rx: Some(rx_raw), - }); + let max_tries = 3; + match ApiFormat::parse(&self.config.format)? 
{
+            ApiFormat::OpenAIChat => {
+                openai::chat::send_stream(self, messages, tools, extra_body, max_tries).await
+            }
+            ApiFormat::OpenAIResponses => {
+                openai::responses::send_stream(self, messages, tools, extra_body, max_tries).await
+            }
+            ApiFormat::Anthropic => {
+                anthropic::request::send_stream(self, messages, tools, extra_body, max_tries).await
+            }
+            ApiFormat::Gemini => {
+                gemini::request::send_stream(self, messages, tools, extra_body, max_tries).await
+            }
         }
-
-        let error_msg = format!(
-            "Stream request failed after {} attempts: {}",
-            max_tries,
-            last_error.unwrap_or_else(|| anyhow!("Unknown error"))
-        );
-        error!("{}", error_msg);
-        Err(anyhow!(error_msg))
     }

-    /// Send a message and wait for the full response (non-streaming)
     pub async fn send_message(
         &self,
         messages: Vec<Message>,
@@ -1783,7 +97,6 @@ impl AIClient {
         .await
     }

-    /// Send a message and wait for the full response (non-streaming, with extra body overrides)
     pub async fn send_message_with_extra_body(
         &self,
         messages: Vec<Message>,
@@ -1793,326 +106,24 @@
         let stream_response = self
             .send_message_stream_with_extra_body(messages, tools, extra_body)
             .await?;
-        let mut stream = stream_response.stream;
-
-        let mut full_text = String::new();
-        let mut full_reasoning = String::new();
-        let mut finish_reason = None;
-        let mut usage = None;
-        let mut provider_metadata: Option<serde_json::Value> = None;
-
-        let mut tool_calls: Vec<ToolCall> = Vec::new();
-        let mut pending_tool_call = PendingToolCall::default();
-
-        while let Some(chunk_result) = stream.next().await {
-            match chunk_result {
-                Ok(chunk) => {
-                    let UnifiedResponse {
-                        text,
-                        reasoning_content,
-                        thinking_signature: _,
-                        tool_call,
-                        usage: chunk_usage,
-                        finish_reason: chunk_finish_reason,
-                        provider_metadata: chunk_provider_metadata,
-                    } = chunk;
-
-                    if let Some(text) = text {
-                        full_text.push_str(&text);
-                    }
-
-                    if let Some(reasoning_content) = reasoning_content {
-                        full_reasoning.push_str(&reasoning_content);
-                    }
-
-                    if let Some(tool_call) = tool_call {
-                        let ai_stream_handlers::UnifiedToolCall {
-                            id,
-                            name,
-                            arguments,
-                        } = tool_call;
-
-                        if let Some(tool_call_id) = id {
-                            if !tool_call_id.is_empty() {
-                                // Some providers repeat the tool id on every delta. Only switch when the id changes.
-                                let is_new_tool = pending_tool_call.tool_id() != tool_call_id;
-                                if is_new_tool {
-                                    if let Some(finalized) =
-                                        pending_tool_call.finalize(ToolCallBoundary::NewTool)
-                                    {
-                                        if finalized.is_error {
-                                            warn!(
-                                                "[send_message] Dropping invalid tool call at boundary=new_tool: tool_id={}, tool_name={}, raw_len={}",
-                                                finalized.tool_id,
-                                                finalized.tool_name,
-                                                finalized.raw_arguments.len()
-                                            );
-                                        } else {
-                                            let arguments = finalized.arguments_as_object_map();
-                                            tool_calls.push(ToolCall {
-                                                id: finalized.tool_id,
-                                                name: finalized.tool_name,
-                                                arguments,
-                                            });
-                                        }
-                                    }
-                                    pending_tool_call.start_new(tool_call_id, name.clone());
-                                    debug!(
-                                        "[send_message] Detected tool call: {}",
-                                        pending_tool_call.tool_name()
-                                    );
-                                } else {
-                                    pending_tool_call.update_tool_name_if_missing(name.clone());
-                                }
-                            }
-                        }
-
-                        if let Some(tool_call_arguments) = arguments {
-                            if pending_tool_call.has_pending() {
-                                pending_tool_call.append_arguments(&tool_call_arguments);
-                            }
-                        }
-                    }
-
-                    if let Some(finish_reason_) = chunk_finish_reason {
-                        if let Some(finalized) =
-                            pending_tool_call.finalize(ToolCallBoundary::FinishReason)
-                        {
-                            if finalized.is_error {
-                                warn!(
-                                    "[send_message] Dropping invalid tool call at boundary=finish_reason: tool_id={}, tool_name={}, raw_len={}",
-                                    finalized.tool_id,
-                                    finalized.tool_name,
-                                    finalized.raw_arguments.len()
-                                );
-                            } else {
-                                let arguments = finalized.arguments_as_object_map();
-                                tool_calls.push(ToolCall {
-                                    id: finalized.tool_id,
-                                    name: finalized.tool_name,
-                                    arguments,
-                                });
-                            }
-                        }
-                        finish_reason = Some(finish_reason_);
-                    }
-
-                    if let Some(chunk_usage) = chunk_usage {
-                        usage = Some(Self::unified_usage_to_gemini_usage(chunk_usage));
-                    }
-
-                    if let Some(chunk_provider_metadata) = chunk_provider_metadata {
-                        match provider_metadata.as_mut() {
-                            Some(existing) => {
-                                Self::merge_json_value(existing, chunk_provider_metadata);
-                            }
-                            None => provider_metadata = Some(chunk_provider_metadata),
-                        }
-                    }
-                }
-                Err(e) => return Err(e),
-            }
-        }
-
-        if let Some(finalized) = pending_tool_call.finalize(ToolCallBoundary::EndOfAggregation) {
-            if finalized.is_error {
-                warn!(
-                    "[send_message] Dropping invalid tool call at boundary=end_of_aggregation: tool_id={}, tool_name={}, raw_len={}",
-                    finalized.tool_id,
-                    finalized.tool_name,
-                    finalized.raw_arguments.len()
-                );
-            } else {
-                let arguments = finalized.arguments_as_object_map();
-                tool_calls.push(ToolCall {
-                    id: finalized.tool_id,
-                    name: finalized.tool_name,
-                    arguments,
-                });
-            }
-        }
-
-        let reasoning_content = if full_reasoning.is_empty() {
-            None
-        } else {
-            Some(full_reasoning)
-        };
-
-        let tool_calls_result = if tool_calls.is_empty() {
-            None
-        } else {
-            Some(tool_calls)
-        };
-
-        let response = GeminiResponse {
-            text: full_text,
-            reasoning_content,
-            tool_calls: tool_calls_result,
-            usage,
-            finish_reason,
-            provider_metadata,
-        };
-
-        Ok(response)
+        response_aggregator::aggregate_stream_response(stream_response).await
     }

     pub async fn test_connection(&self) -> Result<ConnectionTestResult> {
-        let start_time = std::time::Instant::now();
-
-        // Reuse the normal chat request path so the test matches real conversations, even when
-        // a provider rejects stricter tool_choice settings such as "required".
-        let test_messages = vec![Message::user(
-            "Call the get_weather tool for city=Beijing. Do not answer with plain text."
-                .to_string(),
-        )];
-        let tools = Some(vec![ToolDefinition {
-            name: "get_weather".to_string(),
-            description: "Get the weather of a city".to_string(),
-            parameters: serde_json::json!({
-                "type": "object",
-                "properties": {
-                    "city": { "type": "string", "description": "The city to get the weather for" }
-                },
-                "required": ["city"],
-                "additionalProperties": false
-            }),
-        }]);
-
-        let result = self.send_message(test_messages, tools).await;
-
-        match result {
-            Ok(response) => {
-                let response_time_ms = start_time.elapsed().as_millis() as u64;
-                if response.tool_calls.is_some() {
-                    Ok(ConnectionTestResult {
-                        success: true,
-                        response_time_ms,
-                        model_response: Some(response.text),
-                        message_code: None,
-                        error_details: None,
-                    })
-                } else {
-                    Ok(ConnectionTestResult {
-                        success: true,
-                        response_time_ms,
-                        model_response: Some(response.text),
-                        message_code: Some(ConnectionTestMessageCode::ToolCallsNotDetected),
-                        error_details: None,
-                    })
-                }
-            }
-            Err(e) => {
-                let response_time_ms = start_time.elapsed().as_millis() as u64;
-                let error_msg = format!("{}", e);
-                debug!("test connection failed: {}", error_msg);
-                Ok(ConnectionTestResult {
-                    success: false,
-                    response_time_ms,
-                    model_response: None,
-                    message_code: None,
-                    error_details: Some(error_msg),
-                })
-            }
-        }
+        healthcheck::test_connection(self).await
     }

     pub async fn test_image_input_connection(&self) -> Result<ConnectionTestResult> {
-        let start_time = std::time::Instant::now();
-        let provider = self.config.format.to_ascii_lowercase();
-        let prompt = "Inspect the attached image and reply with exactly one 4-letter code for quadrant colors in TL,TR,BL,BR order using letters R,G,B,Y (R=red, G=green, B=blue, Y=yellow).";
-
-        let content = if provider == "anthropic" {
-            serde_json::json!([
-                {
-                    "type": "image",
-                    "source": {
-                        "type": "base64",
-                        "media_type": "image/png",
-                        "data": Self::TEST_IMAGE_PNG_BASE64
-                    }
-                },
-                {
-                    "type": "text",
-                    "text": prompt
-                }
-            ])
-        } else {
-            serde_json::json!([
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": format!("data:image/png;base64,{}", Self::TEST_IMAGE_PNG_BASE64)
-                    }
-                },
-                {
-                    "type": "text",
-                    "text": prompt
-                }
-            ])
-        };
-
-        let test_messages = vec![Message {
-            role: "user".to_string(),
-            content: Some(content.to_string()),
-            reasoning_content: None,
-            thinking_signature: None,
-            tool_calls: None,
-            tool_call_id: None,
-            name: None,
-            is_error: None,
-            tool_image_attachments: None,
-        }];
-
-        match self.send_message(test_messages, None).await {
-            Ok(response) => {
-                let matched = Self::image_test_response_matches_expected(&response.text);
-
-                if matched {
-                    Ok(ConnectionTestResult {
-                        success: true,
-                        response_time_ms: start_time.elapsed().as_millis() as u64,
-                        model_response: Some(response.text),
-                        message_code: None,
-                        error_details: None,
-                    })
-                } else {
-                    let detail = format!(
-                        "Image understanding verification failed: expected code '{}', got response '{}'",
-                        Self::TEST_IMAGE_EXPECTED_CODE, response.text
-                    );
-                    debug!("test image input connection failed: {}", detail);
-                    Ok(ConnectionTestResult {
-                        success: false,
-                        response_time_ms: start_time.elapsed().as_millis() as u64,
-                        model_response: Some(response.text),
-                        message_code: Some(ConnectionTestMessageCode::ImageInputCheckFailed),
-                        error_details: Some(detail),
-                    })
-                }
-            }
-            Err(e) => {
-                let error_msg = format!("{}", e);
-                debug!("test image input connection failed: {}", error_msg);
-                Ok(ConnectionTestResult {
-                    success: false,
-                    response_time_ms: start_time.elapsed().as_millis() as u64,
-                    model_response: None,
-                    message_code: None,
-                    error_details: Some(error_msg),
-                })
-            }
-        }
+        healthcheck::test_image_input_connection(self).await
     }

     pub async fn list_models(&self) -> Result<Vec<RemoteModelInfo>> {
-        match self.get_api_format().to_ascii_lowercase().as_str() {
-            "openai" | "response" | "responses" => self.list_openai_models().await,
-            "anthropic" => self.list_anthropic_models().await,
-            format if Self::is_gemini_api_format(format) => self.list_gemini_models().await,
-            unsupported => Err(anyhow!(
-                "Listing models is not supported for API format: {}",
-                unsupported
-            )),
+        match ApiFormat::parse(&self.config.format)? {
+            ApiFormat::OpenAIChat | ApiFormat::OpenAIResponses => {
+                openai::common::list_models(self).await
+            }
+            ApiFormat::Anthropic => anthropic::discovery::list_models(self).await,
+            ApiFormat::Gemini => gemini::discovery::list_models(self).await,
         }
     }
 }
@@ -2120,13 +131,16 @@
 #[cfg(test)]
 mod tests {
     use super::AIClient;
-    use crate::infrastructure::ai::providers::gemini::GeminiMessageConverter;
+    use crate::infrastructure::ai::providers::{
+        anthropic, gemini, gemini::GeminiMessageConverter, openai,
+    };
+    use crate::service::config::types::ReasoningMode;
     use crate::util::types::{AIConfig, ToolDefinition};
-    use serde_json::json;
+    use serde_json::{json, Value};

-    fn make_test_client(format: &str, custom_request_body: Option<serde_json::Value>) -> AIClient {
+    fn make_test_client(format: &str, custom_request_body: Option<Value>) -> AIClient {
         AIClient::new(AIConfig {
-            name: "test".to_string(),
+            name: format!("{}-test", format),
             base_url: "https://example.com/v1".to_string(),
             request_url: "https://example.com/v1/chat/completions".to_string(),
             api_key: "test-key".to_string(),
@@ -2136,17 +150,24 @@
             max_tokens: Some(8192),
             temperature: None,
             top_p: None,
-            enable_thinking_process: false,
-            support_preserved_thinking: false,
+            reasoning_mode: ReasoningMode::Default,
             inline_think_in_text: false,
             custom_headers: None,
             custom_headers_mode: None,
             skip_ssl_verify: false,
             reasoning_effort: None,
+            thinking_budget_tokens: None,
             custom_request_body,
+            custom_request_body_mode: None,
         })
     }

+    fn make_trim_test_client(format: &str) -> AIClient {
+        let mut client = make_test_client(format, None);
+        client.config.custom_request_body_mode = Some("trim".to_string());
+        client
+    }
+
     #[test]
     fn resolves_openai_models_url_from_completion_endpoint() {
         let client = AIClient::new(AIConfig {
@@ -2160,18 +181,19 @@
             max_tokens: Some(8192),
             temperature: None,
             top_p: None,
-            enable_thinking_process: false,
-            support_preserved_thinking: false,
+            reasoning_mode: ReasoningMode::Default,
             inline_think_in_text: false,
             custom_headers: None,
             custom_headers_mode: None,
             skip_ssl_verify: false,
             reasoning_effort: None,
+            thinking_budget_tokens: None,
             custom_request_body: None,
+            custom_request_body_mode: None,
         });

         assert_eq!(
-            client.resolve_openai_models_url(),
+            openai::common::resolve_models_url(&client),
             "https://api.openai.com/v1/models"
         );
     }
@@ -2189,18 +211,19 @@
             max_tokens: Some(8192),
             temperature: None,
             top_p: None,
-            enable_thinking_process: false,
-            support_preserved_thinking: false,
+            reasoning_mode: ReasoningMode::Default,
             inline_think_in_text: false,
             custom_headers: None,
             custom_headers_mode: None,
             skip_ssl_verify: false,
             reasoning_effort: None,
+            thinking_budget_tokens: None,
             custom_request_body: None,
+            custom_request_body_mode: None,
         });

         assert_eq!(
-            client.resolve_anthropic_models_url(),
+            anthropic::discovery::resolve_models_url(&client),
             "https://api.anthropic.com/v1/models"
         );
     }
@@ -2219,17 +242,19 @@
max_tokens: Some(4096), temperature: Some(0.2), top_p: Some(0.8), - enable_thinking_process: true, - support_preserved_thinking: true, + reasoning_mode: ReasoningMode::Enabled, inline_think_in_text: false, custom_headers: None, custom_headers_mode: None, skip_ssl_verify: false, reasoning_effort: None, + thinking_budget_tokens: None, custom_request_body: None, + custom_request_body_mode: None, }); - let request_body = client.build_gemini_request_body( + let request_body = gemini::request::build_request_body( + &client, None, vec![json!({ "role": "user", @@ -2298,14 +323,15 @@ mod tests { max_tokens: Some(4096), temperature: None, top_p: None, - enable_thinking_process: false, - support_preserved_thinking: true, + reasoning_mode: ReasoningMode::Default, inline_think_in_text: false, custom_headers: None, custom_headers_mode: None, skip_ssl_verify: false, reasoning_effort: None, + thinking_budget_tokens: None, custom_request_body: None, + custom_request_body_mode: None, }); let gemini_tools = GeminiMessageConverter::convert_tools(Some(vec![ToolDefinition { @@ -2319,7 +345,8 @@ mod tests { }), }])); - let request_body = client.build_gemini_request_body( + let request_body = gemini::request::build_request_body( + &client, None, vec![json!({ "role": "user", @@ -2333,6 +360,311 @@ mod tests { assert!(request_body.get("toolConfig").is_none()); } + #[test] + fn build_openai_request_body_uses_generic_thinking_object_when_enabled() { + let client = AIClient::new(AIConfig { + name: "openai-compatible".to_string(), + base_url: "https://example.com/v1".to_string(), + request_url: "https://example.com/v1/chat/completions".to_string(), + api_key: "test-key".to_string(), + model: "test-model".to_string(), + format: "openai".to_string(), + context_window: 128000, + max_tokens: Some(4096), + temperature: None, + top_p: None, + reasoning_mode: ReasoningMode::Enabled, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + thinking_budget_tokens: None, + custom_request_body: None, + custom_request_body_mode: None, + }); + + let request_body = openai::chat::build_request_body( + &client, + &client.config.request_url, + vec![json!({ "role": "user", "content": "hello" })], + None, + None, + ); + + assert_eq!(request_body["thinking"]["type"], "enabled"); + assert!(request_body.get("enable_thinking").is_none()); + assert!(request_body.get("reasoning_split").is_none()); + } + + #[test] + fn build_openai_request_body_uses_enable_thinking_for_siliconflow() { + let client = AIClient::new(AIConfig { + name: "siliconflow".to_string(), + base_url: "https://api.siliconflow.cn/v1".to_string(), + request_url: "https://api.siliconflow.cn/v1/chat/completions".to_string(), + api_key: "test-key".to_string(), + model: "Qwen/Qwen3-Coder-480B-A35B-Instruct".to_string(), + format: "openai".to_string(), + context_window: 128000, + max_tokens: Some(4096), + temperature: None, + top_p: None, + reasoning_mode: ReasoningMode::Enabled, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + thinking_budget_tokens: None, + custom_request_body: None, + custom_request_body_mode: None, + }); + + let request_body = openai::chat::build_request_body( + &client, + &client.config.request_url, + vec![json!({ "role": "user", "content": "hello" })], + None, + None, + ); + + assert_eq!(request_body["enable_thinking"], true); + assert!(request_body.get("thinking").is_none()); + } + + #[test] 
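+    // The Responses API path expresses ReasoningMode::Disabled as
+    // reasoning.effort = "none" instead of a thinking toggle.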
+ fn build_responses_request_body_maps_disabled_mode_to_none_effort() { + let client = AIClient::new(AIConfig { + name: "responses".to_string(), + base_url: "https://api.openai.com/v1".to_string(), + request_url: "https://api.openai.com/v1/responses".to_string(), + api_key: "test-key".to_string(), + model: "gpt-5".to_string(), + format: "responses".to_string(), + context_window: 128000, + max_tokens: Some(4096), + temperature: None, + top_p: None, + reasoning_mode: ReasoningMode::Disabled, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + thinking_budget_tokens: None, + custom_request_body: None, + custom_request_body_mode: None, + }); + + let request_body = openai::responses::build_request_body( + &client, + Some("Be concise".to_string()), + vec![json!({ + "role": "user", + "content": [{ "type": "input_text", "text": "hello" }] + })], + None, + None, + ); + + assert_eq!(request_body["reasoning"]["effort"], "none"); + } + + #[test] + fn build_anthropic_request_body_uses_adaptive_reasoning_and_effort() { + let client = AIClient::new(AIConfig { + name: "anthropic".to_string(), + base_url: "https://api.anthropic.com".to_string(), + request_url: "https://api.anthropic.com/v1/messages".to_string(), + api_key: "test-key".to_string(), + model: "claude-sonnet-4-6".to_string(), + format: "anthropic".to_string(), + context_window: 200000, + max_tokens: Some(8192), + temperature: None, + top_p: None, + reasoning_mode: ReasoningMode::Adaptive, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: Some("high".to_string()), + thinking_budget_tokens: None, + custom_request_body: None, + custom_request_body_mode: None, + }); + + let request_body = anthropic::request::build_request_body( + &client, + &client.config.request_url, + None, + vec![json!({ "role": "user", "content": [{ "type": "text", "text": "hello" }] })], + None, + None, + ); + + assert_eq!(request_body["thinking"]["type"], "adaptive"); + assert_eq!(request_body["output_config"]["effort"], "high"); + } + + #[test] + fn build_openai_request_body_trim_mode_preserves_essential_fields() { + let mut client = make_trim_test_client("openai"); + client.config.max_tokens = Some(8192); + let messages = vec![json!({ "role": "user", "content": "hello" })]; + + let request_body = openai::chat::build_request_body( + &client, + &client.config.request_url, + messages.clone(), + None, + Some(json!({ + "model": "override-model", + "messages": [{ "role": "user", "content": "override" }], + "stream": false, + "max_tokens": 1, + "temperature": 0.7, + "response_format": { "type": "json_object" } + })), + ); + + assert_eq!(request_body["model"], "test-model"); + assert_eq!(request_body["messages"], json!(messages)); + assert_eq!(request_body["stream"], true); + assert_eq!(request_body["max_tokens"], 8192); + assert_eq!(request_body["temperature"], 0.7); + assert_eq!(request_body["response_format"]["type"], "json_object"); + assert!(request_body.get("thinking").is_none()); + } + + #[test] + fn build_responses_request_body_trim_mode_preserves_essential_fields() { + let mut client = make_trim_test_client("responses"); + client.config.max_tokens = Some(4096); + let input = vec![json!({ + "role": "user", + "content": [{ "type": "input_text", "text": "hello" }] + })]; + + let request_body = openai::responses::build_request_body( + &client, + Some("Be concise".to_string()), + input.clone(), + None, + 
Some(json!({ + "instructions": "override me", + "input": [{ "role": "user", "content": [{ "type": "input_text", "text": "override" }] }], + "stream": false, + "max_output_tokens": 1, + "temperature": 0.1 + })), + ); + + assert_eq!(request_body["model"], "test-model"); + assert_eq!(request_body["input"], json!(input)); + assert_eq!(request_body["instructions"], "Be concise"); + assert_eq!(request_body["stream"], true); + assert_eq!(request_body["max_output_tokens"], 4096); + assert_eq!(request_body["temperature"], 0.1); + assert!(request_body.get("reasoning").is_none()); + } + + #[test] + fn build_anthropic_request_body_trim_mode_preserves_essential_fields() { + let mut client = make_trim_test_client("anthropic"); + client.config.max_tokens = Some(8192); + let messages = vec![json!({ + "role": "user", + "content": [{ "type": "text", "text": "hello" }] + })]; + + let request_body = anthropic::request::build_request_body( + &client, + &client.config.request_url, + Some("Use the system prompt".to_string()), + messages.clone(), + None, + Some(json!({ + "system": "override me", + "messages": [{ "role": "user", "content": [{ "type": "text", "text": "override" }] }], + "max_tokens": 1, + "stream": false, + "metadata": { "tag": "kept" } + })), + ); + + assert_eq!(request_body["model"], "test-model"); + assert_eq!(request_body["messages"], json!(messages)); + assert_eq!(request_body["system"], "Use the system prompt"); + assert_eq!(request_body["stream"], true); + assert_eq!(request_body["max_tokens"], 8192); + assert_eq!(request_body["metadata"]["tag"], "kept"); + assert!(request_body.get("thinking").is_none()); + } + + #[test] + fn build_gemini_request_body_trim_mode_preserves_essential_fields() { + let mut client = make_trim_test_client("gemini"); + client.config.model = "gemini-2.5-pro".to_string(); + client.config.max_tokens = Some(4096); + + let contents = vec![json!({ + "role": "user", + "parts": [{ "text": "hello" }] + })]; + let system_instruction = json!({ + "parts": [{ "text": "system" }] + }); + let gemini_tools = GeminiMessageConverter::convert_tools(Some(vec![ToolDefinition { + name: "lookup".to_string(), + description: "Look up data".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "query": { "type": "string" } + }, + "required": ["query"] + }), + }])); + + let request_body = gemini::request::build_request_body( + &client, + Some(system_instruction.clone()), + contents.clone(), + gemini_tools, + Some(json!({ + "contents": [{ "role": "user", "parts": [{ "text": "override" }] }], + "systemInstruction": { "parts": [{ "text": "override system" }] }, + "generationConfig": { + "maxOutputTokens": 1, + "candidateCount": 2 + }, + "tools": [], + "toolConfig": { + "functionCallingConfig": { + "mode": "NONE" + } + }, + "temperature": 0.3 + })), + ); + + assert_eq!(request_body["contents"], json!(contents)); + assert_eq!(request_body["systemInstruction"], system_instruction); + assert_eq!(request_body["generationConfig"]["maxOutputTokens"], 4096); + assert_eq!(request_body["generationConfig"]["candidateCount"], 2); + assert_eq!(request_body["generationConfig"]["temperature"], 0.3); + assert_eq!( + request_body["toolConfig"]["functionCallingConfig"]["mode"], + "AUTO" + ); + assert_eq!( + request_body["tools"][0]["functionDeclarations"][0]["name"], + "lookup" + ); + } + #[test] fn streaming_http_client_does_not_apply_global_request_timeout() { let client = make_test_client("openai", None); diff --git a/src/crates/core/src/infrastructure/ai/client/format.rs 
b/src/crates/core/src/infrastructure/ai/client/format.rs
new file mode 100644
index 000000000..130ee99bb
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/client/format.rs
@@ -0,0 +1,22 @@
+use anyhow::{anyhow, Result};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub(crate) enum ApiFormat {
+    OpenAIChat,
+    OpenAIResponses,
+    Anthropic,
+    Gemini,
+}
+
+impl ApiFormat {
+    pub(crate) fn parse(value: &str) -> Result<Self> {
+        let normalized = value.trim().to_ascii_lowercase();
+        match normalized.as_str() {
+            "openai" => Ok(Self::OpenAIChat),
+            "response" | "responses" => Ok(Self::OpenAIResponses),
+            "anthropic" => Ok(Self::Anthropic),
+            "gemini" | "google" => Ok(Self::Gemini),
+            _ => Err(anyhow!("Unknown API format: {}", value)),
+        }
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/client/healthcheck.rs b/src/crates/core/src/infrastructure/ai/client/healthcheck.rs
new file mode 100644
index 000000000..6c3cd6b07
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/client/healthcheck.rs
@@ -0,0 +1,202 @@
+use crate::infrastructure::ai::client::AIClient;
+use crate::util::types::{
+    ConnectionTestMessageCode, ConnectionTestResult, Message, ToolDefinition,
+};
+use anyhow::Result;
+use log::debug;
+
+pub(crate) fn image_test_response_matches_expected(response: &str) -> bool {
+    let upper = response.to_ascii_uppercase();
+
+    let letters_only: String = upper.chars().filter(|c| c.is_ascii_alphabetic()).collect();
+    if letters_only.contains(AIClient::TEST_IMAGE_EXPECTED_CODE) {
+        return true;
+    }
+
+    let tokens: Vec<&str> = upper
+        .split(|c: char| !c.is_ascii_alphabetic())
+        .filter(|s| !s.is_empty())
+        .collect();
+
+    if tokens.contains(&AIClient::TEST_IMAGE_EXPECTED_CODE) {
+        return true;
+    }
+
+    let single_letter_stream: String = tokens
+        .iter()
+        .filter_map(|token| {
+            if token.len() == 1 {
+                let ch = token.chars().next()?;
+                if matches!(ch, 'R' | 'G' | 'B' | 'Y') {
+                    return Some(ch);
+                }
+            }
+            None
+        })
+        .collect();
+    if single_letter_stream.contains(AIClient::TEST_IMAGE_EXPECTED_CODE) {
+        return true;
+    }
+
+    let color_word_stream: String = tokens
+        .iter()
+        .filter_map(|token| match *token {
+            "RED" => Some('R'),
+            "GREEN" => Some('G'),
+            "BLUE" => Some('B'),
+            "YELLOW" => Some('Y'),
+            _ => None,
+        })
+        .collect();
+    if color_word_stream.contains(AIClient::TEST_IMAGE_EXPECTED_CODE) {
+        return true;
+    }
+
+    let color_letter_stream: String = upper
+        .chars()
+        .filter(|c| matches!(*c, 'R' | 'G' | 'B' | 'Y'))
+        .collect();
+    color_letter_stream.contains(AIClient::TEST_IMAGE_EXPECTED_CODE)
+}
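+
+// A few response shapes the matcher above accepts for the expected code "BYGR":
+//   "BYGR"                  (direct hit in the letters-only stream)
+//   "B, Y, G, R"            (single-letter tokens joined in order)
+//   "blue yellow green red" (color words mapped to their initials)
+// Input is upper-cased first, so case never matters.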
+
+pub(crate) async fn test_connection(client: &AIClient) -> Result<ConnectionTestResult> {
+    let start_time = std::time::Instant::now();
+
+    let test_messages = vec![Message::user(
+        "Call the get_weather tool for city=Beijing. Do not answer with plain text.".to_string(),
+    )];
+    let tools = Some(vec![ToolDefinition {
+        name: "get_weather".to_string(),
+        description: "Get the weather of a city".to_string(),
+        parameters: serde_json::json!({
+            "type": "object",
+            "properties": {
+                "city": { "type": "string", "description": "The city to get the weather for" }
+            },
+            "required": ["city"],
+            "additionalProperties": false
+        }),
+    }]);
+
+    match client.send_message(test_messages, tools).await {
+        Ok(response) => {
+            let response_time_ms = start_time.elapsed().as_millis() as u64;
+            if response.tool_calls.is_some() {
+                Ok(ConnectionTestResult {
+                    success: true,
+                    response_time_ms,
+                    model_response: Some(response.text),
+                    message_code: None,
+                    error_details: None,
+                })
+            } else {
+                Ok(ConnectionTestResult {
+                    success: true,
+                    response_time_ms,
+                    model_response: Some(response.text),
+                    message_code: Some(ConnectionTestMessageCode::ToolCallsNotDetected),
+                    error_details: None,
+                })
+            }
+        }
+        Err(e) => {
+            let response_time_ms = start_time.elapsed().as_millis() as u64;
+            let error_msg = format!("{}", e);
+            debug!("test connection failed: {}", error_msg);
+            Ok(ConnectionTestResult {
+                success: false,
+                response_time_ms,
+                model_response: None,
+                message_code: None,
+                error_details: Some(error_msg),
+            })
+        }
+    }
+}
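+
+// The probe image is a four-quadrant color grid; the expected code "BYGR"
+// corresponds to TL=blue, TR=yellow, BL=green, BR=red in the TL,TR,BL,BR
+// order the prompt spells out.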
+
+pub(crate) async fn test_image_input_connection(client: &AIClient) -> Result<ConnectionTestResult> {
+    let start_time = std::time::Instant::now();
+    let provider = client.config.format.to_ascii_lowercase();
+    let prompt = "Inspect the attached image and reply with exactly one 4-letter code for quadrant colors in TL,TR,BL,BR order using letters R,G,B,Y (R=red, G=green, B=blue, Y=yellow).";
+
+    let content = if provider == "anthropic" {
+        serde_json::json!([
+            {
+                "type": "image",
+                "source": {
+                    "type": "base64",
+                    "media_type": "image/png",
+                    "data": AIClient::TEST_IMAGE_PNG_BASE64
+                }
+            },
+            {
+                "type": "text",
+                "text": prompt
+            }
+        ])
+    } else {
+        serde_json::json!([
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": format!("data:image/png;base64,{}", AIClient::TEST_IMAGE_PNG_BASE64)
+                }
+            },
+            {
+                "type": "text",
+                "text": prompt
+            }
+        ])
+    };
+
+    let test_messages = vec![Message {
+        role: "user".to_string(),
+        content: Some(content.to_string()),
+        reasoning_content: None,
+        thinking_signature: None,
+        tool_calls: None,
+        tool_call_id: None,
+        name: None,
+        is_error: None,
+        tool_image_attachments: None,
+    }];
+
+    match client.send_message(test_messages, None).await {
+        Ok(response) => {
+            if image_test_response_matches_expected(&response.text) {
+                Ok(ConnectionTestResult {
+                    success: true,
+                    response_time_ms: start_time.elapsed().as_millis() as u64,
+                    model_response: Some(response.text),
+                    message_code: None,
+                    error_details: None,
+                })
+            } else {
+                let detail = format!(
+                    "Image understanding verification failed: expected code '{}', got response '{}'",
+                    AIClient::TEST_IMAGE_EXPECTED_CODE,
+                    response.text
+                );
+                debug!("test image input connection failed: {}", detail);
+                Ok(ConnectionTestResult {
+                    success: false,
+                    response_time_ms: start_time.elapsed().as_millis() as u64,
+                    model_response: Some(response.text),
+                    message_code: Some(ConnectionTestMessageCode::ImageInputCheckFailed),
+                    error_details: Some(detail),
+                })
+            }
+        }
+        Err(e) => {
+            let error_msg = format!("{}", e);
+            debug!("test image input connection failed: {}", error_msg);
+            Ok(ConnectionTestResult {
+                success: false,
+                response_time_ms: start_time.elapsed().as_millis() as u64,
+                model_response: None,
+                message_code: None,
+                error_details: Some(error_msg),
+            })
+        }
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/client/http.rs b/src/crates/core/src/infrastructure/ai/client/http.rs
new file mode 100644
index 000000000..c9efeb144
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/client/http.rs
@@ -0,0 +1,77 @@
+use crate::infrastructure::ai::client::AIClient;
+use crate::service::config::ProxyConfig;
+use anyhow::{anyhow, Result};
+use log::{debug, error, info, warn};
+use reqwest::{Client, Proxy};
+
+pub(crate) fn create_http_client(
+    proxy_config: Option<ProxyConfig>,
+    skip_ssl_verify: bool,
+) -> Client {
+    let mut builder = Client::builder()
+        .connect_timeout(std::time::Duration::from_secs(
+            AIClient::STREAM_CONNECT_TIMEOUT_SECS,
+        ))
+        .user_agent("BitFun/1.0")
+        .pool_idle_timeout(std::time::Duration::from_secs(
+            AIClient::HTTP_POOL_IDLE_TIMEOUT_SECS,
+        ))
+        .pool_max_idle_per_host(4)
+        .tcp_keepalive(Some(std::time::Duration::from_secs(
+            AIClient::HTTP_TCP_KEEPALIVE_SECS,
+        )))
+        .danger_accept_invalid_certs(skip_ssl_verify);
+
+    if skip_ssl_verify {
+        warn!(
+            "SSL certificate verification disabled - security risk, use only in test environments"
+        );
+    }
+
+    if let Some(proxy_cfg) = proxy_config {
+        if proxy_cfg.enabled && !proxy_cfg.url.is_empty() {
+            match build_proxy(&proxy_cfg) {
+                Ok(proxy) => {
+                    info!("Using proxy: {}", proxy_cfg.url);
+                    builder = builder.proxy(proxy);
+                }
+                Err(e) => {
+                    error!(
+                        "Proxy configuration failed: {}, proceeding without proxy",
+                        e
+                    );
+                    builder = builder.no_proxy();
+                }
+            }
+        } else {
+            builder = builder.no_proxy();
+        }
+    } else {
+        builder = builder.no_proxy();
+    }
+
+    match builder.build() {
+        Ok(client) => client,
+        Err(e) => {
+            error!(
+                "HTTP client initialization failed: {}, using default client",
+                e
+            );
+            Client::new()
+        }
+    }
+}
+
+fn build_proxy(config: &ProxyConfig) -> Result<Proxy> {
+    let mut proxy =
+        Proxy::all(&config.url).map_err(|e| anyhow!("Failed to create proxy: {}", e))?;
+
+    if let (Some(username), Some(password)) = (&config.username, &config.password) {
+        if !username.is_empty() && !password.is_empty() {
+            proxy = proxy.basic_auth(username, password);
+            debug!("Proxy authentication configured for user: {}", username);
+        }
+    }
+
+    Ok(proxy)
+}
diff --git a/src/crates/core/src/infrastructure/ai/client/quirks.rs b/src/crates/core/src/infrastructure/ai/client/quirks.rs
new file mode 100644
index 000000000..1678ad0ff
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/client/quirks.rs
@@ -0,0 +1,80 @@
+use crate::service::config::types::ReasoningMode;
+
+pub(crate) fn is_dashscope_url(url: &str) -> bool {
+    url.contains("dashscope.aliyuncs.com")
+}
+
+pub(crate) fn is_siliconflow_url(url: &str) -> bool {
+    url.contains("api.siliconflow.cn")
+}
+
+pub(crate) fn is_minimax_url(url: &str) -> bool {
+    url.contains("api.minimaxi.com")
+}
+
+pub(crate) fn parse_glm_major_minor(model_name: &str) -> Option<(u32, u32)> {
+    let lower = model_name.to_ascii_lowercase();
+    let tail = lower.strip_prefix("glm-")?;
+    let mut parts = tail.split('-');
+    let version = parts.next()?;
+
+    let mut version_parts = version.split('.');
+    let major = version_parts.next()?.parse().ok()?;
+    let minor = version_parts
+        .next()
+        .and_then(|value| value.parse().ok())
+        .unwrap_or(0);
+
+    Some((major, minor))
+}
+
+pub(crate) fn should_append_tool_stream(url: &str, model_name: &str) -> bool {
+    if url.contains("bigmodel.cn") {
+        return true;
+    }
+
+    if !url.contains("aliyuncs.com") {
+        return false;
+    }
+
+    parse_glm_major_minor(model_name)
+        .is_some_and(|(major, minor)| major > 4 || (major == 4 && minor >= 5))
+}
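+
+// Illustrative sanity checks for the gate above, derived directly from the
+// parsing rules rather than from any provider documentation.
+#[cfg(test)]
+mod tool_stream_gate_tests {
+    use super::*;
+
+    #[test]
+    fn glm_version_gate_examples() {
+        // "GLM-4.5-Air" parses as (4, 5): new enough for tool_stream on DashScope.
+        assert_eq!(parse_glm_major_minor("GLM-4.5-Air"), Some((4, 5)));
+        // A bare "glm-4" defaults the minor version to 0 and stays gated off.
+        assert_eq!(parse_glm_major_minor("glm-4"), Some((4, 0)));
+        // Non-GLM models never parse.
+        assert_eq!(parse_glm_major_minor("qwen-max"), None);
+
+        // Zhipu's own endpoint always gets tool_stream; DashScope only for GLM >= 4.5.
+        assert!(should_append_tool_stream("https://open.bigmodel.cn/api/paas/v4", "glm-4"));
+        assert!(should_append_tool_stream("https://dashscope.aliyuncs.com/api/v2", "glm-4.5"));
+        assert!(!should_append_tool_stream("https://dashscope.aliyuncs.com/api/v2", "glm-4"));
+    }
+}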
+
+pub(crate) fn apply_openai_compatible_reasoning_fields(
+    request_body: &mut serde_json::Value,
+    mode: ReasoningMode,
+    url: &str,
+) {
+    let normalized_mode = if mode == ReasoningMode::Adaptive {
+        ReasoningMode::Enabled
+    } else {
+        mode
+    };
+
+    if is_dashscope_url(url) || is_siliconflow_url(url) {
+        if normalized_mode != ReasoningMode::Default {
+            request_body["enable_thinking"] =
+                serde_json::json!(normalized_mode == ReasoningMode::Enabled);
+        }
+        return;
+    }
+
+    if is_minimax_url(url) {
+        if normalized_mode == ReasoningMode::Enabled {
+            request_body["reasoning_split"] = serde_json::json!(true);
+        }
+        return;
+    }
+
+    match normalized_mode {
+        ReasoningMode::Default => {}
+        ReasoningMode::Enabled => {
+            request_body["thinking"] = serde_json::json!({ "type": "enabled" });
+        }
+        ReasoningMode::Disabled => {
+            request_body["thinking"] = serde_json::json!({ "type": "disabled" });
+        }
+        ReasoningMode::Adaptive => unreachable!("adaptive mode is normalized above"),
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/client/response_aggregator.rs b/src/crates/core/src/infrastructure/ai/client/response_aggregator.rs
new file mode 100644
index 000000000..c3c022a2f
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/client/response_aggregator.rs
@@ -0,0 +1,174 @@
+use crate::infrastructure::ai::ai_stream_handlers::UnifiedResponse;
+use crate::infrastructure::ai::tool_call_accumulator::{PendingToolCall, ToolCallBoundary};
+use crate::util::types::{GeminiResponse, GeminiUsage, ToolCall};
+use anyhow::Result;
+use futures::StreamExt;
+use log::{debug, warn};
+
+use super::StreamResponse;
+
+pub(crate) async fn aggregate_stream_response(
+    stream_response: StreamResponse,
+) -> Result<GeminiResponse> {
+    let mut stream = stream_response.stream;
+
+    let mut full_text = String::new();
+    let mut full_reasoning = String::new();
+    let mut finish_reason = None;
+    let mut usage = None;
+    let mut provider_metadata: Option<serde_json::Value> = None;
+
+    let mut tool_calls: Vec<ToolCall> = Vec::new();
+    let mut pending_tool_call = PendingToolCall::default();
+
+    while let Some(chunk_result) = stream.next().await {
+        match chunk_result {
+            Ok(chunk) => {
+                let UnifiedResponse {
+                    text,
+                    reasoning_content,
+                    thinking_signature: _,
+                    tool_call,
+                    usage: chunk_usage,
+                    finish_reason: chunk_finish_reason,
+                    provider_metadata: chunk_provider_metadata,
+                } = chunk;
+
+                if let Some(text) = text {
+                    full_text.push_str(&text);
+                }
+
+                if let Some(reasoning_content) = reasoning_content {
+                    full_reasoning.push_str(&reasoning_content);
+                }
+
+                if let Some(tool_call) = tool_call {
+                    let crate::infrastructure::ai::ai_stream_handlers::UnifiedToolCall {
+                        id,
+                        name,
+                        arguments,
+                    } = tool_call;
+
+                    if let Some(tool_call_id) = id {
+                        if !tool_call_id.is_empty() {
+                            let is_new_tool = pending_tool_call.tool_id() != tool_call_id;
+                            if is_new_tool {
+                                if let Some(finalized) =
+                                    pending_tool_call.finalize(ToolCallBoundary::NewTool)
+                                {
+                                    if finalized.is_error {
+                                        warn!(
+                                            "[send_message] Dropping invalid tool call at boundary=new_tool: tool_id={}, tool_name={}, raw_len={}",
+                                            finalized.tool_id,
+                                            finalized.tool_name,
+                                            finalized.raw_arguments.len()
+                                        );
+                                    } else {
+                                        let arguments = finalized.arguments_as_object_map();
+                                        tool_calls.push(ToolCall {
+                                            id: finalized.tool_id,
+                                            name: finalized.tool_name,
+                                            arguments,
+                                        });
+                                    }
+                                }
+                                pending_tool_call.start_new(tool_call_id, name.clone());
+                                debug!(
+                                    "[send_message] Detected tool call: {}",
+                                    pending_tool_call.tool_name()
+                                );
+                            } else {
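+                                // Same tool id repeated on a later delta: only
+                                // backfill a name that arrived late; argument
+                                // fragments keep accumulating below.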
+                                pending_tool_call.update_tool_name_if_missing(name.clone());
+                            }
+                        }
+                    }
+
+                    if let Some(tool_call_arguments) = arguments {
+                        if pending_tool_call.has_pending() {
+                            pending_tool_call.append_arguments(&tool_call_arguments);
+                        }
+                    }
+                }
+
+                if let Some(finish_reason_) = chunk_finish_reason {
+                    if let Some(finalized) =
+                        pending_tool_call.finalize(ToolCallBoundary::FinishReason)
+                    {
+                        if finalized.is_error {
+                            warn!(
+                                "[send_message] Dropping invalid tool call at boundary=finish_reason: tool_id={}, tool_name={}, raw_len={}",
+                                finalized.tool_id,
+                                finalized.tool_name,
+                                finalized.raw_arguments.len()
+                            );
+                        } else {
+                            let arguments = finalized.arguments_as_object_map();
+                            tool_calls.push(ToolCall {
+                                id: finalized.tool_id,
+                                name: finalized.tool_name,
+                                arguments,
+                            });
+                        }
+                    }
+                    finish_reason = Some(finish_reason_);
+                }
+
+                if let Some(chunk_usage) = chunk_usage {
+                    usage = Some(unified_usage_to_gemini_usage(chunk_usage));
+                }
+
+                if let Some(chunk_provider_metadata) = chunk_provider_metadata {
+                    match provider_metadata.as_mut() {
+                        Some(existing) => {
+                            crate::infrastructure::ai::client::utils::merge_json_value(
+                                existing,
+                                chunk_provider_metadata,
+                            );
+                        }
+                        None => provider_metadata = Some(chunk_provider_metadata),
+                    }
+                }
+            }
+            Err(e) => return Err(e),
+        }
+    }
+
+    if let Some(finalized) = pending_tool_call.finalize(ToolCallBoundary::EndOfAggregation) {
+        if finalized.is_error {
+            warn!(
+                "[send_message] Dropping invalid tool call at boundary=end_of_aggregation: tool_id={}, tool_name={}, raw_len={}",
+                finalized.tool_id,
+                finalized.tool_name,
+                finalized.raw_arguments.len()
+            );
+        } else {
+            let arguments = finalized.arguments_as_object_map();
+            tool_calls.push(ToolCall {
+                id: finalized.tool_id,
+                name: finalized.tool_name,
+                arguments,
+            });
+        }
+    }
+
+    Ok(GeminiResponse {
+        text: full_text,
+        reasoning_content: (!full_reasoning.is_empty()).then_some(full_reasoning),
+        tool_calls: (!tool_calls.is_empty()).then_some(tool_calls),
+        usage,
+        finish_reason,
+        provider_metadata,
+    })
+}
+
+pub(crate) fn unified_usage_to_gemini_usage(
+    usage: crate::infrastructure::ai::ai_stream_handlers::UnifiedTokenUsage,
+) -> GeminiUsage {
+    GeminiUsage {
+        prompt_token_count: usage.prompt_token_count,
+        candidates_token_count: usage.candidates_token_count,
+        total_token_count: usage.total_token_count,
+        reasoning_token_count: usage.reasoning_token_count,
+        cached_content_token_count: usage.cached_content_token_count,
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/client/sse.rs b/src/crates/core/src/infrastructure/ai/client/sse.rs
new file mode 100644
index 000000000..535cbe908
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/client/sse.rs
@@ -0,0 +1,128 @@
+use crate::infrastructure::ai::ai_stream_handlers::UnifiedResponse;
+use crate::infrastructure::ai::client::StreamResponse;
+use anyhow::{anyhow, Result};
+use log::{debug, error, warn};
+use tokio::sync::mpsc;
+
+pub(crate) async fn execute_sse_request<BuildRequest, SpawnHandler>(
+    label: &str,
+    _url: &str,
+    request_body: &serde_json::Value,
+    max_tries: usize,
+    build_request: BuildRequest,
+    spawn_handler: SpawnHandler,
+) -> Result<StreamResponse>
+where
+    BuildRequest: Fn() -> reqwest::RequestBuilder,
+    SpawnHandler: Fn(
+        reqwest::Response,
+        mpsc::UnboundedSender<Result<UnifiedResponse>>,
+        Option<mpsc::UnboundedSender<String>>,
+    ),
+{
+    let mut last_error = None;
+    let base_wait_time_ms = 500;
+
+    for attempt in 0..max_tries {
+        let request_start_time = std::time::Instant::now();
+        let response_result = build_request().json(request_body).send().await;
+
+        let response = match response_result {
+            Ok(resp) => {
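+                // 4xx is treated as fatal and returned immediately; any other
+                // non-success status falls through to the retry below, whose
+                // backoff is 500ms << min(attempt, 3): 500ms, 1s, 2s, capped at 4s.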
let connect_time = request_start_time.elapsed().as_millis(); + let status = resp.status(); + + if status.is_client_error() { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + error!("{} client error {}: {}", label, status, error_text); + return Err(anyhow!("{} client error {}: {}", label, status, error_text)); + } + + if status.is_success() { + debug!( + "{} request connected: {}ms, status: {}, attempt: {}/{}", + label, + connect_time, + status, + attempt + 1, + max_tries + ); + resp + } else { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + let error = anyhow!("{} error {}: {}", label, status, error_text); + warn!( + "{} request failed: {}ms, attempt {}/{}, error: {}", + label, + connect_time, + attempt + 1, + max_tries, + error + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!( + "Retrying {} after {}ms (attempt {})", + label, + delay_ms, + attempt + 2 + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + } + Err(e) => { + let connect_time = request_start_time.elapsed().as_millis(); + let error = anyhow!("{} connection failed: {}", label, e); + warn!( + "{} connection failed: {}ms, attempt {}/{}, error: {}", + label, + connect_time, + attempt + 1, + max_tries, + e + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!( + "Retrying {} after {}ms (attempt {})", + label, + delay_ms, + attempt + 2 + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + }; + + let (tx, rx) = mpsc::unbounded_channel(); + let (tx_raw, rx_raw) = mpsc::unbounded_channel(); + spawn_handler(response, tx, Some(tx_raw)); + + return Ok(StreamResponse { + stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), + raw_sse_rx: Some(rx_raw), + }); + } + + let error_msg = format!( + "{} failed after {} attempts: {}", + label, + max_tries, + last_error.unwrap_or_else(|| anyhow!("Unknown error")) + ); + error!("{}", error_msg); + Err(anyhow!(error_msg)) +} diff --git a/src/crates/core/src/infrastructure/ai/client/utils.rs b/src/crates/core/src/infrastructure/ai/client/utils.rs new file mode 100644 index 000000000..768ec0f1d --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/client/utils.rs @@ -0,0 +1,82 @@ +use crate::util::types::{AIConfig, RemoteModelInfo}; + +pub(crate) fn merge_json_value(target: &mut serde_json::Value, overlay: serde_json::Value) { + match (target, overlay) { + (serde_json::Value::Object(target_map), serde_json::Value::Object(overlay_map)) => { + for (key, value) in overlay_map { + let entry = target_map.entry(key).or_insert(serde_json::Value::Null); + merge_json_value(entry, value); + } + } + (target_slot, overlay_value) => { + *target_slot = overlay_value; + } + } +} + +pub(crate) fn is_trim_custom_request_body_mode(config: &AIConfig) -> bool { + config.custom_request_body_mode.as_deref() == Some("trim") +} + +pub(crate) fn build_request_body_subset( + source: &serde_json::Value, + top_level_keys: &[&str], + nested_fields: &[(&str, &str)], +) -> serde_json::Value { + let mut subset = serde_json::Map::new(); + + if let Some(source_obj) = source.as_object() { + for key in top_level_keys { + if let Some(value) = source_obj.get(*key) { + subset.insert((*key).to_string(), value.clone()); + } + } + } 
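The deep-merge semantics of merge_json_value above are easiest to check in isolation: object keys merge recursively, while any non-object slot is overwritten wholesale. A minimal standalone sketch of those semantics, assuming only serde_json (the function is re-declared locally here for illustration, not the crate-private helper itself):

use serde_json::{json, Value};

// Local re-sketch: object keys merge recursively; everything else is
// replaced by the overlay value.
fn merge(target: &mut Value, overlay: Value) {
    match (target, overlay) {
        (Value::Object(target_map), Value::Object(overlay_map)) => {
            for (key, value) in overlay_map {
                merge(target_map.entry(key).or_insert(Value::Null), value);
            }
        }
        (slot, value) => *slot = value,
    }
}

fn main() {
    let mut body = json!({"generationConfig": {"temperature": 0.2, "topP": 0.9}});
    merge(&mut body, json!({"generationConfig": {"temperature": 0.7}, "stream": true}));
    // topP survives, temperature is overwritten, stream is added alongside.
    assert_eq!(
        body,
        json!({"generationConfig": {"temperature": 0.7, "topP": 0.9}, "stream": true})
    );
}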
+
+    for (parent, child) in nested_fields {
+        let Some(child_value) = source
+            .get(*parent)
+            .and_then(serde_json::Value::as_object)
+            .and_then(|parent_obj| parent_obj.get(*child))
+            .cloned()
+        else {
+            continue;
+        };
+
+        let parent_entry = subset
+            .entry((*parent).to_string())
+            .or_insert_with(|| serde_json::json!({}));
+
+        if !parent_entry.is_object() {
+            *parent_entry = serde_json::json!({});
+        }
+
+        parent_entry
+            .as_object_mut()
+            .expect("protected request subset parent must be object")
+            .insert((*child).to_string(), child_value);
+    }
+
+    serde_json::Value::Object(subset)
+}
+
+pub(crate) fn dedupe_remote_models(models: Vec<RemoteModelInfo>) -> Vec<RemoteModelInfo> {
+    let mut seen = std::collections::HashSet::new();
+    let mut deduped = Vec::new();
+
+    for model in models {
+        if seen.insert(model.id.clone()) {
+            deduped.push(model);
+        }
+    }
+
+    deduped
+}
+
+pub(crate) fn normalize_base_url_for_discovery(base_url: &str) -> String {
+    base_url
+        .trim()
+        .trim_end_matches('#')
+        .trim_end_matches('/')
+        .to_string()
+}
diff --git a/src/crates/core/src/infrastructure/ai/providers/anthropic/discovery.rs b/src/crates/core/src/infrastructure/ai/providers/anthropic/discovery.rs
new file mode 100644
index 000000000..dd2fa5459
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/providers/anthropic/discovery.rs
@@ -0,0 +1,64 @@
+use crate::infrastructure::ai::client::utils::{
+    dedupe_remote_models, normalize_base_url_for_discovery,
+};
+use crate::infrastructure::ai::client::AIClient;
+use crate::util::types::RemoteModelInfo;
+use anyhow::Result;
+use serde::Deserialize;
+
+use super::request::apply_headers;
+
+#[derive(Debug, Deserialize)]
+struct AnthropicModelsResponse {
+    data: Vec<AnthropicModelEntry>,
+}
+
+#[derive(Debug, Deserialize)]
+struct AnthropicModelEntry {
+    id: String,
+    #[serde(default)]
+    display_name: Option<String>,
+}
+
+pub(crate) fn resolve_models_url(client: &AIClient) -> String {
+    let mut base = normalize_base_url_for_discovery(&client.config.base_url);
+
+    if base.ends_with("/v1/messages") {
+        base.truncate(base.len() - "/v1/messages".len());
+        return format!("{}/v1/models", base);
+    }
+
+    if base.ends_with("/v1/models") {
+        return base;
+    }
+
+    if base.ends_with("/v1") {
+        return format!("{}/models", base);
+    }
+
+    if base.is_empty() {
+        return "v1/models".to_string();
+    }
+
+    format!("{}/v1/models", base)
+}
+
+pub(crate) async fn list_models(client: &AIClient) -> Result<Vec<RemoteModelInfo>> {
+    let url = resolve_models_url(client);
+    let response = apply_headers(client, client.client.get(&url), &url)
+        .send()
+        .await?
+        .error_for_status()?;
+
+    let payload: AnthropicModelsResponse = response.json().await?;
+    Ok(dedupe_remote_models(
+        payload
+            .data
+            .into_iter()
+            .map(|model| RemoteModelInfo {
+                id: model.id,
+                display_name: model.display_name,
+            })
+            .collect(),
+    ))
+}
diff --git a/src/crates/core/src/infrastructure/ai/providers/anthropic/mod.rs b/src/crates/core/src/infrastructure/ai/providers/anthropic/mod.rs
index e01d67102..54e06d976 100644
--- a/src/crates/core/src/infrastructure/ai/providers/anthropic/mod.rs
+++ b/src/crates/core/src/infrastructure/ai/providers/anthropic/mod.rs
@@ -2,6 +2,8 @@
 //!
 //!
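The branch order in resolve_models_url above is easiest to read as input/output pairs. A standalone std-only mirror of that logic for quick checking (hosts are illustrative; the real function shares normalize_base_url_for_discovery with the other providers):

// Local mirror of the Anthropic discovery-URL branches.
fn models_url(base: &str) -> String {
    let mut base = base.trim().trim_end_matches('#').trim_end_matches('/').to_string();
    if base.ends_with("/v1/messages") {
        base.truncate(base.len() - "/v1/messages".len());
        return format!("{}/v1/models", base);
    }
    if base.ends_with("/v1/models") {
        return base;
    }
    if base.ends_with("/v1") {
        return format!("{}/models", base);
    }
    if base.is_empty() {
        return "v1/models".to_string();
    }
    format!("{}/v1/models", base)
}

fn main() {
    assert_eq!(
        models_url("https://api.anthropic.com/v1/messages"),
        "https://api.anthropic.com/v1/models"
    );
    assert_eq!(
        models_url("https://api.anthropic.com/v1/"),
        "https://api.anthropic.com/v1/models"
    );
    assert_eq!(
        models_url("https://proxy.example.com"),
        "https://proxy.example.com/v1/models"
    );
}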
Implements interaction with Anthropic Claude models +pub mod discovery; pub mod message_converter; +pub mod request; pub use message_converter::AnthropicMessageConverter; diff --git a/src/crates/core/src/infrastructure/ai/providers/anthropic/request.rs b/src/crates/core/src/infrastructure/ai/providers/anthropic/request.rs new file mode 100644 index 000000000..73763292a --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/anthropic/request.rs @@ -0,0 +1,233 @@ +use super::AnthropicMessageConverter; +use crate::infrastructure::ai::ai_stream_handlers::handle_anthropic_stream; +use crate::infrastructure::ai::client::quirks::should_append_tool_stream; +use crate::infrastructure::ai::client::sse::execute_sse_request; +use crate::infrastructure::ai::client::{AIClient, StreamResponse}; +use crate::infrastructure::ai::providers::shared; +use crate::service::config::types::ReasoningMode; +use crate::util::types::{Message, ToolDefinition}; +use anyhow::Result; +use log::{debug, warn}; +use reqwest::RequestBuilder; + +pub(crate) fn apply_headers( + client: &AIClient, + builder: RequestBuilder, + url: &str, +) -> RequestBuilder { + shared::apply_header_policy(client, builder, |mut builder| { + builder = builder.header("Content-Type", "application/json"); + + if url.contains("bigmodel.cn") { + builder = builder.header("Authorization", format!("Bearer {}", client.config.api_key)); + } else { + builder = builder + .header("x-api-key", &client.config.api_key) + .header("anthropic-version", "2023-06-01"); + } + + if url.contains("openbitfun.com") { + builder = builder.header("X-Verification-Code", "from_bitfun"); + } + + builder + }) +} + +fn anthropic_supports_adaptive_reasoning(model_name: &str) -> bool { + matches!( + model_name, + name if name.starts_with("claude-opus-4-6") + || name.starts_with("claude-sonnet-4-6") + || name.starts_with("claude-mythos") + ) +} + +fn anthropic_supports_thinking_budget(model_name: &str) -> bool { + model_name.starts_with("claude") +} + +fn default_anthropic_budget_tokens(max_tokens: Option) -> Option { + max_tokens.map(|value| 10_000u32.min(value.saturating_mul(3) / 4)) +} + +fn apply_reasoning_fields( + request_body: &mut serde_json::Value, + mode: ReasoningMode, + model_name: &str, + max_tokens: Option, + reasoning_effort: Option<&str>, + thinking_budget_tokens: Option, +) { + match mode { + ReasoningMode::Default => {} + ReasoningMode::Disabled => { + request_body["thinking"] = serde_json::json!({ "type": "disabled" }); + } + ReasoningMode::Enabled => { + let mut thinking = serde_json::json!({ "type": "enabled" }); + if anthropic_supports_thinking_budget(model_name) { + if let Some(budget_tokens) = + thinking_budget_tokens.or_else(|| default_anthropic_budget_tokens(max_tokens)) + { + thinking["budget_tokens"] = serde_json::json!(budget_tokens); + } + } + request_body["thinking"] = thinking; + } + ReasoningMode::Adaptive => { + if anthropic_supports_adaptive_reasoning(model_name) { + request_body["thinking"] = serde_json::json!({ "type": "adaptive" }); + if let Some(effort) = reasoning_effort.filter(|value| !value.trim().is_empty()) { + request_body["output_config"] = serde_json::json!({ + "effort": effort + }); + } + } else { + warn!( + target: "ai::anthropic_stream_request", + "Model {} does not advertise Anthropic adaptive reasoning support; falling back to manual thinking", + model_name + ); + apply_reasoning_fields( + request_body, + ReasoningMode::Enabled, + model_name, + max_tokens, + None, + thinking_budget_tokens, + ); + } + } + } + + if mode != 
ReasoningMode::Adaptive + && reasoning_effort.is_some_and(|value| !value.trim().is_empty()) + { + warn!( + target: "ai::anthropic_stream_request", + "Ignoring reasoning_effort for Anthropic model {} because effort currently applies only to adaptive reasoning mode", + model_name + ); + } +} + +pub(crate) fn build_request_body( + client: &AIClient, + url: &str, + system_message: Option, + anthropic_messages: Vec, + anthropic_tools: Option>, + extra_body: Option, +) -> serde_json::Value { + let max_tokens = client.config.max_tokens.unwrap_or(8192); + + let mut request_body = serde_json::json!({ + "model": client.config.model, + "messages": anthropic_messages, + "max_tokens": max_tokens, + "stream": true + }); + + let model_name = client.config.model.to_lowercase(); + + if should_append_tool_stream(url, &model_name) { + request_body["tool_stream"] = serde_json::Value::Bool(true); + } + + apply_reasoning_fields( + &mut request_body, + client.config.reasoning_mode, + &model_name, + Some(max_tokens), + client.config.reasoning_effort.as_deref(), + client.config.thinking_budget_tokens, + ); + + if let Some(system) = system_message { + request_body["system"] = serde_json::Value::String(system); + } + + let protected_body = shared::protect_request_body( + client, + &mut request_body, + &[ + "model", + "messages", + "max_tokens", + "stream", + "system", + "tool_stream", + ], + &[], + ); + + if let Some(extra) = extra_body { + if let Some(extra_obj) = extra.as_object() { + shared::merge_extra_body(&mut request_body, extra_obj); + shared::log_extra_body_keys("ai::anthropic_stream_request", extra_obj); + } + } + + shared::restore_protected_body(&mut request_body, protected_body); + + shared::log_request_body( + "ai::anthropic_stream_request", + "Anthropic stream request body (excluding tools):", + &request_body, + ); + + if let Some(tools) = anthropic_tools { + let tool_names = tools + .iter() + .map(|tool| { + shared::extract_top_level_string_field(tool, "name") + .unwrap_or_else(|| "unknown".to_string()) + }) + .collect::>(); + shared::log_tool_names("ai::anthropic_stream_request", tool_names); + if !tools.is_empty() { + request_body["tools"] = serde_json::Value::Array(tools); + } + } + + request_body +} + +pub(crate) async fn send_stream( + client: &AIClient, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, +) -> Result { + let url = client.config.request_url.clone(); + debug!( + "Anthropic config: model={}, request_url={}, max_tries={}", + client.config.model, client.config.request_url, max_tries + ); + + let (system_message, anthropic_messages) = + AnthropicMessageConverter::convert_messages(messages); + let anthropic_tools = AnthropicMessageConverter::convert_tools(tools); + let request_body = build_request_body( + client, + &url, + system_message, + anthropic_messages, + anthropic_tools, + extra_body, + ); + + execute_sse_request( + "Anthropic Streaming API", + &url, + &request_body, + max_tries, + || apply_headers(client, client.client.post(&url), &url), + move |response, tx, tx_raw| { + tokio::spawn(handle_anthropic_stream(response, tx, tx_raw)); + }, + ) + .await +} diff --git a/src/crates/core/src/infrastructure/ai/providers/gemini/discovery.rs b/src/crates/core/src/infrastructure/ai/providers/gemini/discovery.rs new file mode 100644 index 000000000..9423c6dc4 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/gemini/discovery.rs @@ -0,0 +1,73 @@ +use crate::infrastructure::ai::client::utils::dedupe_remote_models; +use 
crate::infrastructure::ai::client::AIClient;
+use crate::util::types::RemoteModelInfo;
+use anyhow::Result;
+use log::debug;
+use serde::Deserialize;
+
+use super::request::{apply_headers, gemini_base_url};
+
+#[derive(Debug, Deserialize)]
+struct GeminiModelsResponse {
+    #[serde(default)]
+    models: Vec<GeminiModelEntry>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct GeminiModelEntry {
+    name: String,
+    #[serde(default)]
+    display_name: Option<String>,
+    #[serde(default, deserialize_with = "deserialize_null_as_default")]
+    supported_generation_methods: Vec<String>,
+}
+
+fn deserialize_null_as_default<'de, D, T>(deserializer: D) -> std::result::Result<T, D::Error>
+where
+    D: serde::Deserializer<'de>,
+    T: Default + serde::Deserialize<'de>,
+{
+    Option::<T>::deserialize(deserializer).map(|value| value.unwrap_or_default())
+}
+
+pub(crate) fn resolve_models_url(client: &AIClient) -> String {
+    let base = gemini_base_url(&client.config.base_url);
+    format!("{}/v1beta/models", base)
+}
+
+pub(crate) async fn list_models(client: &AIClient) -> Result<Vec<RemoteModelInfo>> {
+    let url = resolve_models_url(client);
+    debug!("Gemini models list URL: {}", url);
+
+    let response = apply_headers(client, client.client.get(&url))
+        .send()
+        .await?
+        .error_for_status()?;
+
+    let payload: GeminiModelsResponse = response.json().await?;
+    Ok(dedupe_remote_models(
+        payload
+            .models
+            .into_iter()
+            .filter(|model| {
+                model.supported_generation_methods.is_empty()
+                    || model
+                        .supported_generation_methods
+                        .iter()
+                        .any(|method| method == "generateContent")
+            })
+            .map(|model| {
+                let id = model
+                    .name
+                    .strip_prefix("models/")
+                    .unwrap_or(&model.name)
+                    .to_string();
+                RemoteModelInfo {
+                    id,
+                    display_name: model.display_name,
+                }
+            })
+            .collect(),
+    ))
+}
diff --git a/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs b/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs
index ee6d89d2e..c038cd969 100644
--- a/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs
+++ b/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs
@@ -1,5 +1,7 @@
 //!
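On the request side (gemini/request.rs below), gemini_base_url and resolve_request_url reduce whatever base URL was saved to the host root and rebuild the canonical SSE endpoint. A rough std-only mirror of that rebuild; the real code additionally percent-encodes the model name via the urlencoding crate and returns an empty string for an empty base:

// Local mirror: strip from "/v1beta" or "/models/" onward, then append
// the canonical streaming path for the given model.
fn stream_url(base: &str, model: &str) -> String {
    let mut b = base.trim().trim_end_matches('/');
    if let Some(pos) = b.find("/v1beta") {
        b = &b[..pos];
    }
    if let Some(pos) = b.find("/models/") {
        b = &b[..pos];
    }
    format!(
        "{}/v1beta/models/{}:streamGenerateContent?alt=sse",
        b.trim_end_matches('/'),
        model.trim()
    )
}

fn main() {
    assert_eq!(
        stream_url(
            "https://generativelanguage.googleapis.com/v1beta/models/old:generateContent",
            "gemini-2.0-flash"
        ),
        "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse"
    );
}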
Gemini provider module +pub mod discovery; pub mod message_converter; +pub mod request; pub use message_converter::GeminiMessageConverter; diff --git a/src/crates/core/src/infrastructure/ai/providers/gemini/request.rs b/src/crates/core/src/infrastructure/ai/providers/gemini/request.rs new file mode 100644 index 000000000..9ceee3ec5 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/gemini/request.rs @@ -0,0 +1,348 @@ +use super::GeminiMessageConverter; +use crate::infrastructure::ai::ai_stream_handlers::handle_gemini_stream; +use crate::infrastructure::ai::client::sse::execute_sse_request; +use crate::infrastructure::ai::client::{AIClient, StreamResponse}; +use crate::infrastructure::ai::providers::shared; +use crate::service::config::types::ReasoningMode; +use crate::util::types::{Message, ToolDefinition}; +use anyhow::Result; +use log::debug; +use reqwest::RequestBuilder; + +pub(crate) fn apply_headers(client: &AIClient, builder: RequestBuilder) -> RequestBuilder { + shared::apply_header_policy(client, builder, |mut builder| { + builder = builder + .header("Content-Type", "application/json") + .header("x-goog-api-key", &client.config.api_key) + .header("Authorization", format!("Bearer {}", client.config.api_key)); + + if client.config.base_url.contains("openbitfun.com") { + builder = builder.header("X-Verification-Code", "from_bitfun"); + } + + builder + }) +} + +pub(crate) fn gemini_base_url(url: &str) -> &str { + let mut value = url.trim().trim_end_matches('/'); + if let Some(pos) = value.find("/v1beta") { + value = &value[..pos]; + } + if let Some(pos) = value.find("/models/") { + value = &value[..pos]; + } + value.trim_end_matches('/') +} + +pub(crate) fn resolve_request_url(base_url: &str, model_name: &str) -> String { + let trimmed = base_url.trim().trim_end_matches('/'); + if trimmed.is_empty() { + return String::new(); + } + + let base = gemini_base_url(trimmed); + let encoded_model = urlencoding::encode(model_name.trim()); + format!( + "{}/v1beta/models/{}:streamGenerateContent?alt=sse", + base, encoded_model + ) +} + +fn apply_reasoning_fields(request_body: &mut serde_json::Value, mode: ReasoningMode) { + if matches!(mode, ReasoningMode::Enabled | ReasoningMode::Adaptive) { + insert_generation_field( + request_body, + "thinkingConfig", + serde_json::json!({ + "includeThoughts": true, + }), + ); + } +} + +fn ensure_generation_config( + request_body: &mut serde_json::Value, +) -> &mut serde_json::Map { + if !request_body + .get("generationConfig") + .is_some_and(serde_json::Value::is_object) + { + request_body["generationConfig"] = serde_json::json!({}); + } + + request_body["generationConfig"] + .as_object_mut() + .expect("generationConfig must be an object") +} + +fn insert_generation_field( + request_body: &mut serde_json::Value, + key: &str, + value: serde_json::Value, +) { + ensure_generation_config(request_body).insert(key.to_string(), value); +} + +fn normalize_stop_sequences(value: &serde_json::Value) -> Option { + match value { + serde_json::Value::String(sequence) => { + Some(serde_json::Value::Array(vec![serde_json::Value::String( + sequence.clone(), + )])) + } + serde_json::Value::Array(items) => { + let sequences = items + .iter() + .filter_map(|item| item.as_str().map(|sequence| sequence.to_string())) + .map(serde_json::Value::String) + .collect::>(); + + if sequences.is_empty() { + None + } else { + Some(serde_json::Value::Array(sequences)) + } + } + _ => None, + } +} + +fn apply_response_format_translation( + request_body: &mut serde_json::Value, 
+ response_format: &serde_json::Value, +) -> bool { + match response_format { + serde_json::Value::String(kind) if matches!(kind.as_str(), "json" | "json_object") => { + insert_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + true + } + serde_json::Value::Object(map) => { + let Some(kind) = map.get("type").and_then(serde_json::Value::as_str) else { + return false; + }; + + match kind { + "json" | "json_object" => { + insert_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + true + } + "json_schema" => { + insert_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + + if let Some(schema) = map + .get("json_schema") + .and_then(serde_json::Value::as_object) + .and_then(|json_schema| json_schema.get("schema")) + .or_else(|| map.get("schema")) + { + insert_generation_field( + request_body, + "responseJsonSchema", + GeminiMessageConverter::sanitize_schema(schema.clone()), + ); + } + + true + } + _ => false, + } + } + _ => false, + } +} + +fn translate_extra_body( + request_body: &mut serde_json::Value, + extra_obj: &mut serde_json::Map, +) { + if let Some(max_tokens) = extra_obj.remove("max_tokens") { + insert_generation_field(request_body, "maxOutputTokens", max_tokens); + } + + if let Some(temperature) = extra_obj.remove("temperature") { + insert_generation_field(request_body, "temperature", temperature); + } + + let top_p = extra_obj + .remove("top_p") + .or_else(|| extra_obj.remove("topP")); + if let Some(top_p) = top_p { + insert_generation_field(request_body, "topP", top_p); + } + + if let Some(stop_sequences) = extra_obj.get("stop").and_then(normalize_stop_sequences) { + extra_obj.remove("stop"); + insert_generation_field(request_body, "stopSequences", stop_sequences); + } + + if let Some(response_mime_type) = extra_obj + .remove("responseMimeType") + .or_else(|| extra_obj.remove("response_mime_type")) + { + insert_generation_field(request_body, "responseMimeType", response_mime_type); + } + + if let Some(response_schema) = extra_obj + .remove("responseJsonSchema") + .or_else(|| extra_obj.remove("responseSchema")) + .or_else(|| extra_obj.remove("response_schema")) + { + insert_generation_field( + request_body, + "responseJsonSchema", + GeminiMessageConverter::sanitize_schema(response_schema), + ); + } + + if let Some(response_format) = extra_obj.get("response_format").cloned() { + if apply_response_format_translation(request_body, &response_format) { + extra_obj.remove("response_format"); + } + } +} + +pub(crate) fn build_request_body( + client: &AIClient, + system_instruction: Option, + contents: Vec, + gemini_tools: Option>, + extra_body: Option, +) -> serde_json::Value { + let mut request_body = serde_json::json!({ + "contents": contents, + }); + + if let Some(system_instruction) = system_instruction { + request_body["systemInstruction"] = system_instruction; + } + + if let Some(max_tokens) = client.config.max_tokens { + insert_generation_field( + &mut request_body, + "maxOutputTokens", + serde_json::json!(max_tokens), + ); + } + + if let Some(temperature) = client.config.temperature { + insert_generation_field( + &mut request_body, + "temperature", + serde_json::json!(temperature), + ); + } + + if let Some(top_p) = client.config.top_p { + insert_generation_field(&mut request_body, "topP", serde_json::json!(top_p)); + } + + apply_reasoning_fields(&mut request_body, 
client.config.reasoning_mode); + + if let Some(tools) = gemini_tools { + let tool_names = tools + .iter() + .flat_map(shared::collect_function_declaration_names_or_object_keys) + .collect::>(); + shared::log_tool_names("ai::gemini_stream_request", tool_names); + + if !tools.is_empty() { + request_body["tools"] = serde_json::Value::Array(tools); + let has_function_declarations = request_body["tools"] + .as_array() + .map(|tools| { + tools + .iter() + .any(|tool| tool.get("functionDeclarations").is_some()) + }) + .unwrap_or(false); + + if has_function_declarations { + request_body["toolConfig"] = serde_json::json!({ + "functionCallingConfig": { + "mode": "AUTO" + } + }); + } + } + } + + let protected_body = shared::protect_request_body( + client, + &mut request_body, + &["contents", "systemInstruction", "tools", "toolConfig"], + &[("generationConfig", "maxOutputTokens")], + ); + + if let Some(extra) = extra_body { + if let Some(mut extra_obj) = extra.as_object().cloned() { + translate_extra_body(&mut request_body, &mut extra_obj); + let override_keys = extra_obj.keys().cloned().collect::>(); + shared::merge_extra_body_recursively(&mut request_body, extra_obj); + debug!( + target: "ai::gemini_stream_request", + "Applied extra_body overrides: {:?}", + override_keys + ); + } + } + + shared::restore_protected_body(&mut request_body, protected_body); + + shared::log_request_body( + "ai::gemini_stream_request", + "Gemini stream request body:", + &request_body, + ); + + request_body +} + +pub(crate) async fn send_stream( + client: &AIClient, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, +) -> Result { + let url = resolve_request_url(&client.config.request_url, &client.config.model); + debug!( + "Gemini config: model={}, request_url={}, max_tries={}", + client.config.model, url, max_tries + ); + + let (system_instruction, contents) = + GeminiMessageConverter::convert_messages(messages, &client.config.model); + let gemini_tools = GeminiMessageConverter::convert_tools(tools); + let request_body = build_request_body( + client, + system_instruction, + contents, + gemini_tools, + extra_body, + ); + + execute_sse_request( + "Gemini Streaming API", + &url, + &request_body, + max_tries, + || apply_headers(client, client.client.post(&url)), + move |response, tx, tx_raw| { + tokio::spawn(handle_gemini_stream(response, tx, tx_raw)); + }, + ) + .await +} diff --git a/src/crates/core/src/infrastructure/ai/providers/mod.rs b/src/crates/core/src/infrastructure/ai/providers/mod.rs index 452cfabc7..5d3923162 100644 --- a/src/crates/core/src/infrastructure/ai/providers/mod.rs +++ b/src/crates/core/src/infrastructure/ai/providers/mod.rs @@ -5,6 +5,7 @@ pub mod anthropic; pub mod gemini; pub mod openai; +pub(crate) mod shared; pub use anthropic::AnthropicMessageConverter; pub use gemini::GeminiMessageConverter; diff --git a/src/crates/core/src/infrastructure/ai/providers/openai/chat.rs b/src/crates/core/src/infrastructure/ai/providers/openai/chat.rs new file mode 100644 index 000000000..d9b77cb7c --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/openai/chat.rs @@ -0,0 +1,107 @@ +use super::{common, OpenAIMessageConverter}; +use crate::infrastructure::ai::ai_stream_handlers::handle_openai_stream; +use crate::infrastructure::ai::client::quirks::should_append_tool_stream; +use crate::infrastructure::ai::client::sse::execute_sse_request; +use crate::infrastructure::ai::client::{AIClient, StreamResponse}; +use crate::infrastructure::ai::providers::shared; +use 
crate::util::types::{Message, ToolDefinition}; +use anyhow::Result; +use log::{debug, warn}; + +pub(crate) fn build_request_body( + client: &AIClient, + url: &str, + openai_messages: Vec, + openai_tools: Option>, + extra_body: Option, +) -> serde_json::Value { + let mut request_body = serde_json::json!({ + "model": client.config.model, + "messages": openai_messages, + "stream": true + }); + + let model_name = client.config.model.to_lowercase(); + + if should_append_tool_stream(url, &model_name) { + request_body["tool_stream"] = serde_json::Value::Bool(true); + } + + common::apply_reasoning_fields(&mut request_body, client, url); + + if let Some(max_tokens) = client.config.max_tokens { + request_body["max_tokens"] = serde_json::json!(max_tokens); + } + + let protected_body = shared::protect_request_body( + client, + &mut request_body, + &["model", "messages", "stream", "max_tokens", "tool_stream"], + &[], + ); + + if let Some(extra) = extra_body { + if let Some(extra_obj) = extra.as_object() { + shared::merge_extra_body(&mut request_body, extra_obj); + shared::log_extra_body_keys("ai::openai_stream_request", extra_obj); + } + } + + shared::restore_protected_body(&mut request_body, protected_body); + + if let Some(request_obj) = request_body.as_object_mut() { + if let Some(existing_n) = request_obj.remove("n") { + warn!( + target: "ai::openai_stream_request", + "Removed custom request field n={} because the stream processor only handles the first choice", + existing_n + ); + } + } + + shared::log_request_body( + "ai::openai_stream_request", + "OpenAI stream request body (excluding tools):", + &request_body, + ); + + common::attach_tools(&mut request_body, openai_tools, "ai::openai_stream_request"); + + request_body +} + +pub(crate) async fn send_stream( + client: &AIClient, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, +) -> Result { + let url = client.config.request_url.clone(); + debug!( + "OpenAI config: model={}, request_url={}, max_tries={}", + client.config.model, client.config.request_url, max_tries + ); + + let openai_messages = OpenAIMessageConverter::convert_messages(messages); + let openai_tools = OpenAIMessageConverter::convert_tools(tools); + let request_body = build_request_body(client, &url, openai_messages, openai_tools, extra_body); + let inline_think_in_text = client.config.inline_think_in_text; + + execute_sse_request( + "OpenAI Streaming API", + &url, + &request_body, + max_tries, + || common::apply_headers(client, client.client.post(&url)), + move |response, tx, tx_raw| { + tokio::spawn(handle_openai_stream( + response, + tx, + tx_raw, + inline_think_in_text, + )); + }, + ) + .await +} diff --git a/src/crates/core/src/infrastructure/ai/providers/openai/common.rs b/src/crates/core/src/infrastructure/ai/providers/openai/common.rs new file mode 100644 index 000000000..bfcf75ea5 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/openai/common.rs @@ -0,0 +1,107 @@ +use crate::infrastructure::ai::client::quirks::apply_openai_compatible_reasoning_fields; +use crate::infrastructure::ai::client::utils::{ + dedupe_remote_models, normalize_base_url_for_discovery, +}; +use crate::infrastructure::ai::client::AIClient; +use crate::infrastructure::ai::providers::shared; +use crate::util::types::RemoteModelInfo; +use anyhow::Result; +use reqwest::RequestBuilder; +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +struct OpenAIModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAIModelEntry { + id: String, 
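One behavior worth calling out in attach_tools (defined in common.rs just below): a default tool_choice of "auto" is injected only when the tools array is non-empty and the caller has not already pinned a non-null tool_choice. A small sketch of that guard, assuming serde_json:

use serde_json::{json, Value};

fn main() {
    let mut body = json!({"model": "m"});
    // Mirrors attach_tools: tools first, then the conditional default.
    body["tools"] = json!([{ "function": { "name": "read_file" } }]);
    let has_tool_choice = body.get("tool_choice").is_some_and(|v| !v.is_null());
    if !has_tool_choice {
        body["tool_choice"] = Value::String("auto".to_string());
    }
    assert_eq!(body["tool_choice"], "auto");
}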
+}
+
+pub(crate) fn apply_headers(client: &AIClient, builder: RequestBuilder) -> RequestBuilder {
+    shared::apply_header_policy(client, builder, |mut builder| {
+        builder = builder
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", client.config.api_key));
+
+        if client.config.base_url.contains("openbitfun.com") {
+            builder = builder.header("X-Verification-Code", "from_bitfun");
+        }
+
+        builder
+    })
+}
+
+pub(crate) fn apply_reasoning_fields(
+    request_body: &mut serde_json::Value,
+    client: &AIClient,
+    url: &str,
+) {
+    apply_openai_compatible_reasoning_fields(request_body, client.config.reasoning_mode, url);
+}
+
+pub(crate) fn resolve_models_url(client: &AIClient) -> String {
+    let mut base = normalize_base_url_for_discovery(&client.config.base_url);
+
+    for suffix in ["/chat/completions", "/responses", "/models"] {
+        if base.ends_with(suffix) {
+            base.truncate(base.len() - suffix.len());
+            break;
+        }
+    }
+
+    if base.is_empty() {
+        return "models".to_string();
+    }
+
+    format!("{}/models", base)
+}
+
+pub(crate) async fn list_models(client: &AIClient) -> Result<Vec<RemoteModelInfo>> {
+    let url = resolve_models_url(client);
+    let response = apply_headers(client, client.client.get(&url))
+        .send()
+        .await?
+        .error_for_status()?;
+
+    let payload: OpenAIModelsResponse = response.json().await?;
+    Ok(dedupe_remote_models(
+        payload
+            .data
+            .into_iter()
+            .map(|model| RemoteModelInfo {
+                id: model.id,
+                display_name: None,
+            })
+            .collect(),
+    ))
+}
+
+pub(crate) fn extract_tool_name(tool: &serde_json::Value) -> String {
+    tool.get("function")
+        .and_then(|function| function.get("name"))
+        .and_then(|name| name.as_str())
+        .unwrap_or("unknown")
+        .to_string()
+}
+
+pub(crate) fn attach_tools(
+    request_body: &mut serde_json::Value,
+    tools: Option<Vec<serde_json::Value>>,
+    target: &str,
+) {
+    if let Some(tools) = tools {
+        let tool_names = tools.iter().map(extract_tool_name).collect::<Vec<_>>();
+        shared::log_tool_names(target, tool_names);
+        if !tools.is_empty() {
+            request_body["tools"] = serde_json::Value::Array(tools);
+            let has_tool_choice = request_body
+                .get("tool_choice")
+                .is_some_and(|value| !value.is_null());
+            if !has_tool_choice {
+                request_body["tool_choice"] = serde_json::Value::String("auto".to_string());
+            }
+        }
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs b/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs
index d0760a257..8f9a5312d 100644
--- a/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs
+++ b/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs
@@ -309,6 +309,10 @@ impl OpenAIMessageConverter {
 
         if let Some(reasoning) = msg.reasoning_content {
             if !reasoning.is_empty() {
+                // Official OpenAI Chat Completions may ignore replayed reasoning_content, but
+                // many OpenAI-compatible providers require it to continue interleaved thinking.
+                // Replaying it here is therefore the compatibility default; at worst this only
+                // adds transport cost for providers that ignore the field.
                 openai_msg["reasoning_content"] = Value::String(reasoning);
             }
         }
diff --git a/src/crates/core/src/infrastructure/ai/providers/openai/mod.rs b/src/crates/core/src/infrastructure/ai/providers/openai/mod.rs
index 44ad1060b..63258daca 100644
--- a/src/crates/core/src/infrastructure/ai/providers/openai/mod.rs
+++ b/src/crates/core/src/infrastructure/ai/providers/openai/mod.rs
@@ -1,5 +1,8 @@
 //!
OpenAI provider module +pub mod chat; +pub mod common; pub mod message_converter; +pub mod responses; pub use message_converter::OpenAIMessageConverter; diff --git a/src/crates/core/src/infrastructure/ai/providers/openai/responses.rs b/src/crates/core/src/infrastructure/ai/providers/openai/responses.rs new file mode 100644 index 000000000..9645abfe6 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/openai/responses.rs @@ -0,0 +1,124 @@ +use super::{common, OpenAIMessageConverter}; +use crate::infrastructure::ai::ai_stream_handlers::handle_responses_stream; +use crate::infrastructure::ai::client::sse::execute_sse_request; +use crate::infrastructure::ai::client::{AIClient, StreamResponse}; +use crate::infrastructure::ai::providers::shared; +use crate::service::config::types::ReasoningMode; +use crate::util::types::{Message, ToolDefinition}; +use anyhow::Result; +use log::debug; + +pub(crate) fn build_request_body( + client: &AIClient, + instructions: Option, + response_input: Vec, + openai_tools: Option>, + extra_body: Option, +) -> serde_json::Value { + let mut request_body = serde_json::json!({ + "model": client.config.model, + "input": response_input, + "stream": true + }); + + if let Some(instructions) = instructions.filter(|value| !value.trim().is_empty()) { + request_body["instructions"] = serde_json::Value::String(instructions); + } + + if let Some(max_tokens) = client.config.max_tokens { + request_body["max_output_tokens"] = serde_json::json!(max_tokens); + } + + let responses_effort = client + .config + .reasoning_effort + .as_deref() + .filter(|value| !value.trim().is_empty()) + .map(str::to_string) + .or_else(|| { + if client.config.reasoning_mode == ReasoningMode::Disabled { + Some("none".to_string()) + } else { + None + } + }); + + if let Some(effort) = responses_effort { + request_body["reasoning"] = serde_json::json!({ + "effort": effort + }); + } + + let protected_body = shared::protect_request_body( + client, + &mut request_body, + &[ + "model", + "input", + "instructions", + "stream", + "max_output_tokens", + ], + &[], + ); + + if let Some(extra) = extra_body { + if let Some(extra_obj) = extra.as_object() { + shared::merge_extra_body(&mut request_body, extra_obj); + shared::log_extra_body_keys("ai::responses_stream_request", extra_obj); + } + } + + shared::restore_protected_body(&mut request_body, protected_body); + + shared::log_request_body( + "ai::responses_stream_request", + "Responses stream request body (excluding tools):", + &request_body, + ); + + common::attach_tools( + &mut request_body, + openai_tools, + "ai::responses_stream_request", + ); + + request_body +} + +pub(crate) async fn send_stream( + client: &AIClient, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, +) -> Result { + let url = client.config.request_url.clone(); + debug!( + "Responses config: model={}, request_url={}, max_tries={}", + client.config.model, client.config.request_url, max_tries + ); + + let (instructions, response_input) = + OpenAIMessageConverter::convert_messages_to_responses_input(messages); + let openai_tools = OpenAIMessageConverter::convert_tools(tools); + let request_body = build_request_body( + client, + instructions, + response_input, + openai_tools, + extra_body, + ); + + execute_sse_request( + "Responses API", + &url, + &request_body, + max_tries, + || common::apply_headers(client, client.client.post(&url)), + move |response, tx, tx_raw| { + tokio::spawn(handle_responses_stream(response, tx, tx_raw)); + }, + ) + .await +} diff 
--git a/src/crates/core/src/infrastructure/ai/providers/shared.rs b/src/crates/core/src/infrastructure/ai/providers/shared.rs new file mode 100644 index 000000000..5b0694686 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/shared.rs @@ -0,0 +1,148 @@ +use crate::infrastructure::ai::client::utils::{ + build_request_body_subset, is_trim_custom_request_body_mode, merge_json_value, +}; +use crate::infrastructure::ai::client::AIClient; +use reqwest::RequestBuilder; + +pub(crate) fn apply_header_policy( + client: &AIClient, + builder: RequestBuilder, + apply_defaults: F, +) -> RequestBuilder +where + F: FnOnce(RequestBuilder) -> RequestBuilder, +{ + let has_custom_headers = client + .config + .custom_headers + .as_ref() + .is_some_and(|headers| !headers.is_empty()); + let is_merge_mode = client.config.custom_headers_mode.as_deref() != Some("replace"); + + if has_custom_headers && !is_merge_mode { + return apply_custom_headers(client, builder); + } + + let mut builder = apply_defaults(builder); + + if has_custom_headers && is_merge_mode { + builder = apply_custom_headers(client, builder); + } + + builder +} + +pub(crate) fn apply_custom_headers( + client: &AIClient, + mut builder: RequestBuilder, +) -> RequestBuilder { + if let Some(custom_headers) = &client.config.custom_headers { + if !custom_headers.is_empty() { + for (key, value) in custom_headers { + builder = builder.header(key.as_str(), value.as_str()); + } + } + } + + builder +} + +pub(crate) fn protect_request_body( + client: &AIClient, + request_body: &mut serde_json::Value, + top_level_keys: &[&str], + nested_fields: &[(&str, &str)], +) -> Option { + let protected_body = is_trim_custom_request_body_mode(&client.config) + .then(|| build_request_body_subset(request_body, top_level_keys, nested_fields)); + + if let Some(protected_body) = &protected_body { + *request_body = protected_body.clone(); + } + + protected_body +} + +pub(crate) fn restore_protected_body( + request_body: &mut serde_json::Value, + protected_body: Option, +) { + if let Some(protected_body) = protected_body { + merge_json_value(request_body, protected_body); + } +} + +pub(crate) fn merge_extra_body( + request_body: &mut serde_json::Value, + extra_obj: &serde_json::Map, +) { + for (key, value) in extra_obj { + request_body[key] = value.clone(); + } +} + +pub(crate) fn merge_extra_body_recursively( + request_body: &mut serde_json::Value, + extra_obj: serde_json::Map, +) { + for (key, value) in extra_obj { + if let Some(request_obj) = request_body.as_object_mut() { + let target = request_obj.entry(key).or_insert(serde_json::Value::Null); + merge_json_value(target, value); + } + } +} + +pub(crate) fn log_extra_body_keys( + target: &str, + extra_obj: &serde_json::Map, +) { + log::debug!( + target: target, + "Applied extra_body overrides: {:?}", + extra_obj.keys().collect::>() + ); +} + +pub(crate) fn log_request_body(target: &str, label: &str, request_body: &serde_json::Value) { + log::debug!( + target: target, + "{}\n{}", + label, + serde_json::to_string_pretty(request_body) + .unwrap_or_else(|_| "serialization failed".to_string()) + ); +} + +pub(crate) fn log_tool_names(target: &str, tool_names: Vec) { + log::debug!(target: target, "\ntools: {:?}", tool_names); +} + +pub(crate) fn extract_top_level_string_field( + value: &serde_json::Value, + key: &str, +) -> Option { + value + .get(key) + .and_then(serde_json::Value::as_str) + .map(str::to_string) +} + +pub(crate) fn collect_function_declaration_names_or_object_keys( + tool: &serde_json::Value, +) -> 
Vec<String> {
+    if let Some(declarations) = tool
+        .get("functionDeclarations")
+        .and_then(serde_json::Value::as_array)
+    {
+        declarations
+            .iter()
+            .filter_map(|declaration| extract_top_level_string_field(declaration, "name"))
+            .collect()
+    } else {
+        tool.as_object()
+            .into_iter()
+            .flat_map(|map| map.keys().cloned())
+            .collect()
+    }
+}
diff --git a/src/crates/core/src/service/announcement/content_loader.rs b/src/crates/core/src/service/announcement/content_loader.rs
index 9c8db853f..e32cd4a02 100644
--- a/src/crates/core/src/service/announcement/content_loader.rs
+++ b/src/crates/core/src/service/announcement/content_loader.rs
@@ -16,7 +16,7 @@
 use super::types::{
     AnnouncementCard, CardSource, CardType, CompletionAction, ModalConfig, ModalPage, ModalSize,
-    PageLayout, TriggerCondition, TriggerRule, ToastConfig,
+    PageLayout, ToastConfig, TriggerCondition, TriggerRule,
 };
 
 include!(concat!(env!("OUT_DIR"), "/embedded_announcements.rs"));
@@ -58,7 +58,9 @@ fn split_front_matter(src: &str) -> Option<(&str, &str)> {
     }
     let after_open = &src[3..];
     // Skip optional newline immediately after opening `---`
-    let after_open = after_open.trim_start_matches('\n').trim_start_matches("\r\n");
+    let after_open = after_open
+        .trim_start_matches('\n')
+        .trim_start_matches("\r\n");
     let close = after_open.find("\n---")?;
     let fm = &after_open[..close];
     let body = &after_open[close + 4..]; // skip "\n---"
@@ -92,7 +94,11 @@ fn parse_tip_front_matter(fm: &str) -> Option<TipFrontMatter> {
     if id.is_empty() {
         return None;
     }
-    Some(TipFrontMatter { id, nth_open, auto_dismiss_secs })
+    Some(TipFrontMatter {
+        id,
+        nth_open,
+        auto_dismiss_secs,
+    })
 }
 
 fn parse_feature_front_matter(fm: &str) -> Option<FeatureFrontMatter> {
diff --git a/src/crates/core/src/service/announcement/remote.rs b/src/crates/core/src/service/announcement/remote.rs
index 3701d0dde..4cf8695c0 100644
--- a/src/crates/core/src/service/announcement/remote.rs
+++ b/src/crates/core/src/service/announcement/remote.rs
@@ -104,8 +104,8 @@ async fn load_disk_cache(&self) -> Vec<AnnouncementCard> {
         match fs::read_to_string(&self.cache_file).await {
             Ok(content) => {
-                let cards = serde_json::from_str::<Vec<AnnouncementCard>>(&content)
-                    .unwrap_or_default();
+                let cards =
+                    serde_json::from_str::<Vec<AnnouncementCard>>(&content).unwrap_or_default();
                 let mut lock = self.cached.write().await;
                 *lock = cards.clone();
                 cards
@@ -117,11 +117,10 @@ async fn remote_url() -> Option<String> {
         use crate::service::config::get_global_config_service;
         match get_global_config_service().await {
-            Ok(svc) => {
-                svc.get_config::<String>(Some("announcements.remote_url"))
-                    .await
-                    .ok()
-            }
+            Ok(svc) => svc
+                .get_config::<String>(Some("announcements.remote_url"))
+                .await
+                .ok(),
             Err(_) => None,
         }
     }
diff --git a/src/crates/core/src/service/announcement/scheduler.rs b/src/crates/core/src/service/announcement/scheduler.rs
index 87103718a..9e0885e32 100644
--- a/src/crates/core/src/service/announcement/scheduler.rs
+++ b/src/crates/core/src/service/announcement/scheduler.rs
@@ -4,8 +4,8 @@
 //! It updates the persistent state, merges all card sources and returns the
 //! ordered list of cards that should be presented in this session.
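The contract of split_front_matter above is compact: the source must open with ---, the front matter runs to the next \n---, and the body begins after that 4-byte marker (so a newline immediately following it stays in the body). A standalone rendering with a worked example:

// Local copy of split_front_matter's slicing for illustration.
fn split_front_matter(src: &str) -> Option<(&str, &str)> {
    if !src.starts_with("---") {
        return None;
    }
    let after_open = &src[3..];
    // Skip an optional newline right after the opening `---`.
    let after_open = after_open.trim_start_matches('\n').trim_start_matches("\r\n");
    let close = after_open.find("\n---")?;
    Some((&after_open[..close], &after_open[close + 4..]))
}

fn main() {
    let doc = "---\nid: tip_1\n---\nBody text";
    let (fm, body) = split_front_matter(doc).unwrap();
    assert_eq!(fm, "id: tip_1");
    // The newline after the closing `---` is part of the body.
    assert_eq!(body, "\nBody text");
}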
-use super::remote::RemoteFetcher; use super::registry::local_cards; +use super::remote::RemoteFetcher; use super::state_store::AnnouncementStateStore; use super::tips_pool::builtin_tips; use super::types::{AnnouncementCard, AnnouncementState, TriggerCondition}; diff --git a/src/crates/core/src/service/announcement/state_store.rs b/src/crates/core/src/service/announcement/state_store.rs index a2b3e11c5..e0776f64f 100644 --- a/src/crates/core/src/service/announcement/state_store.rs +++ b/src/crates/core/src/service/announcement/state_store.rs @@ -26,10 +26,11 @@ impl AnnouncementStateStore { pub async fn load(&self) -> BitFunResult { match fs::read_to_string(&self.state_file).await { Ok(content) => { - let state = serde_json::from_str::(&content).unwrap_or_else(|e| { - warn!("Failed to parse announcement state, using default: {}", e); - AnnouncementState::default() - }); + let state = + serde_json::from_str::(&content).unwrap_or_else(|e| { + warn!("Failed to parse announcement state, using default: {}", e); + AnnouncementState::default() + }); debug!("Loaded announcement state from {:?}", self.state_file); Ok(state) } diff --git a/src/crates/core/src/service/config/types.rs b/src/crates/core/src/service/config/types.rs index f75129e97..2f31e88b4 100644 --- a/src/crates/core/src/service/config/types.rs +++ b/src/crates/core/src/service/config/types.rs @@ -366,6 +366,22 @@ pub enum ModelCategory { SpeechRecognition, } +/// Provider-agnostic reasoning mode. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +#[derive(Default)] +pub enum ReasoningMode { + /// Omit provider-specific reasoning fields and use the upstream API default behavior. + #[default] + Default, + /// Explicitly enable reasoning / thinking output when the provider supports it. + Enabled, + /// Explicitly disable reasoning / thinking output when the provider supports it. + Disabled, + /// Use provider-native adaptive reasoning when supported. + Adaptive, +} + /// Default model configuration. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] @@ -803,7 +819,7 @@ impl Default for SubAgentConfig { } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(default)] +#[serde(default, from = "AIModelConfigCompat")] pub struct AIModelConfig { pub id: String, pub name: String, @@ -823,8 +839,6 @@ pub struct AIModelConfig { pub max_tokens: Option, pub temperature: Option, pub top_p: Option, - pub frequency_penalty: Option, - pub presence_penalty: Option, pub enabled: bool, /// Model category (primary category used for UI filtering). pub category: ModelCategory, @@ -836,14 +850,17 @@ pub struct AIModelConfig { /// Additional metadata (JSON, for extensibility). pub metadata: Option, - /// Whether to display the thinking process (for hybrid/thinking models such as o1). - #[serde(default)] + /// Compatibility-only input field for older saved configs. + /// + /// New code should use `reasoning_mode`. This field is deserialized for migration and + /// compatibility, then omitted from future saves. When `reasoning_mode` is absent, `true` + /// maps to `enabled` and `false` maps to `default`. + #[serde(default, skip_serializing)] pub enable_thinking_process: bool, - /// Whether preserved thinking is supported (Preserved Thinking). - /// If false, `reasoning_content` from previous turns is ignored when sending messages. - #[serde(default)] - pub support_preserved_thinking: bool, + /// Provider-agnostic reasoning mode. 
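Since ReasoningMode derives Serialize/Deserialize with rename_all = "snake_case", its wire format is exactly the lowercase strings used by the remote catalog later in this patch. A quick round-trip sketch with a local copy of the enum (the real one lives in service::config::types), assuming serde and serde_json:

use serde::{Deserialize, Serialize};

// Local copy for illustration; mirrors the patch's derives.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum ReasoningMode {
    #[default]
    Default,
    Enabled,
    Disabled,
    Adaptive,
}

fn main() {
    assert_eq!(
        serde_json::to_string(&ReasoningMode::Adaptive).unwrap(),
        "\"adaptive\""
    );
    let parsed: ReasoningMode = serde_json::from_str("\"enabled\"").unwrap();
    assert_eq!(parsed, ReasoningMode::Enabled);
}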
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_mode: Option, /// Whether to parse OpenAI-compatible text chunks containing `...` into /// streaming reasoning content. @@ -863,14 +880,109 @@ pub struct AIModelConfig { #[serde(default)] pub skip_ssl_verify: bool, - /// Reasoning effort level for OpenAI Responses API (o-series / GPT-5+). - /// Valid values: "low", "medium", "high", "xhigh". None = use API default. + /// Reasoning effort level for providers that support explicit effort controls. + /// Valid values are provider-specific. None = use API default. #[serde(default, skip_serializing_if = "Option::is_none")] pub reasoning_effort: Option, + /// Optional Anthropic manual thinking token budget. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub thinking_budget_tokens: Option, + /// Custom request body (JSON string, used to override default request body fields). #[serde(default)] pub custom_request_body: Option, + + /// Custom request body mode: "merge" (default) or "trim" (keep only essential runtime + /// fields, then apply custom JSON). + #[serde(default)] + pub custom_request_body_mode: Option, +} + +#[derive(Debug, Clone, Deserialize, Default)] +#[serde(default)] +struct AIModelConfigCompat { + id: String, + name: String, + provider: String, + model_name: String, + base_url: String, + request_url: Option, + api_key: String, + context_window: Option, + max_tokens: Option, + temperature: Option, + top_p: Option, + enabled: bool, + category: ModelCategory, + capabilities: Vec, + recommended_for: Vec, + metadata: Option, + enable_thinking_process: Option, + reasoning_mode: Option, + inline_think_in_text: bool, + custom_headers: Option>, + custom_headers_mode: Option, + skip_ssl_verify: bool, + reasoning_effort: Option, + thinking_budget_tokens: Option, + custom_request_body: Option, + custom_request_body_mode: Option, +} + +impl From for AIModelConfig { + fn from(value: AIModelConfigCompat) -> Self { + let reasoning_mode = value.reasoning_mode.or_else(|| { + value.enable_thinking_process.map(|enabled| { + if enabled { + ReasoningMode::Enabled + } else { + ReasoningMode::Default + } + }) + }); + + Self { + id: value.id, + name: value.name, + provider: value.provider, + model_name: value.model_name, + base_url: value.base_url, + request_url: value.request_url, + api_key: value.api_key, + context_window: value.context_window, + max_tokens: value.max_tokens, + temperature: value.temperature, + top_p: value.top_p, + enabled: value.enabled, + category: value.category, + capabilities: value.capabilities, + recommended_for: value.recommended_for, + metadata: value.metadata, + enable_thinking_process: value.enable_thinking_process.unwrap_or(false), + reasoning_mode, + inline_think_in_text: value.inline_think_in_text, + custom_headers: value.custom_headers, + custom_headers_mode: value.custom_headers_mode, + skip_ssl_verify: value.skip_ssl_verify, + reasoning_effort: value.reasoning_effort, + thinking_budget_tokens: value.thinking_budget_tokens, + custom_request_body: value.custom_request_body, + custom_request_body_mode: value.custom_request_body_mode, + } + } +} + +impl AIModelConfig { + pub fn effective_reasoning_mode(&self) -> ReasoningMode { + self.reasoning_mode.unwrap_or({ + if self.enable_thinking_process { + ReasoningMode::Enabled + } else { + ReasoningMode::Default + } + }) + } } /// Proxy configuration. 
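The precedence in effective_reasoning_mode above is the whole migration story: an explicit reasoning_mode always wins; otherwise the legacy boolean maps true to Enabled and false to Default (not Disabled), so older configs keep provider-default behavior. A stand-in sketch of that rule:

#[derive(Clone, Copy, PartialEq, Debug)]
enum ReasoningMode { Default, Enabled, Disabled, Adaptive }

// Stand-in for AIModelConfig::effective_reasoning_mode's precedence.
fn effective(reasoning_mode: Option<ReasoningMode>, legacy_flag: bool) -> ReasoningMode {
    reasoning_mode.unwrap_or(if legacy_flag {
        ReasoningMode::Enabled
    } else {
        ReasoningMode::Default
    })
}

fn main() {
    // An explicit mode wins over the legacy flag.
    assert_eq!(effective(Some(ReasoningMode::Disabled), true), ReasoningMode::Disabled);
    assert_eq!(effective(Some(ReasoningMode::Adaptive), false), ReasoningMode::Adaptive);
    // Legacy true upgrades to Enabled; legacy false stays on the API default.
    assert_eq!(effective(None, true), ReasoningMode::Enabled);
    assert_eq!(effective(None, false), ReasoningMode::Default);
}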
@@ -1247,21 +1359,21 @@ impl Default for AIModelConfig { max_tokens: None, temperature: None, top_p: None, - frequency_penalty: None, - presence_penalty: None, enabled: false, category: ModelCategory::GeneralChat, capabilities: vec![], recommended_for: vec![], metadata: None, enable_thinking_process: false, - support_preserved_thinking: false, + reasoning_mode: None, inline_think_in_text: false, custom_headers: None, custom_headers_mode: None, skip_ssl_verify: false, reasoning_effort: None, + thinking_budget_tokens: None, custom_request_body: None, + custom_request_body_mode: None, } } } @@ -1429,3 +1541,67 @@ impl AIModelConfig { } } } + +#[cfg(test)] +mod tests { + use super::{AIModelConfig, ReasoningMode}; + + #[test] + fn deserializes_compatibility_thinking_flag_into_reasoning_mode() { + let config: AIModelConfig = serde_json::from_value(serde_json::json!({ + "id": "model_1", + "name": "Provider", + "provider": "openai", + "model_name": "test-model", + "base_url": "https://example.com/v1", + "api_key": "key", + "enabled": true, + "enable_thinking_process": true + })) + .expect("legacy config should deserialize"); + + assert_eq!(config.reasoning_mode, Some(ReasoningMode::Enabled)); + assert!(config.enable_thinking_process); + } + + #[test] + fn deserializes_compatibility_false_thinking_flag_into_default_reasoning_mode() { + let config: AIModelConfig = serde_json::from_value(serde_json::json!({ + "id": "model_1", + "name": "Provider", + "provider": "openai", + "model_name": "test-model", + "base_url": "https://example.com/v1", + "api_key": "key", + "enabled": true, + "enable_thinking_process": false + })) + .expect("legacy config should deserialize"); + + assert_eq!(config.reasoning_mode, Some(ReasoningMode::Default)); + assert!(!config.enable_thinking_process); + } + + #[test] + fn serialization_omits_compatibility_thinking_flag() { + let config: AIModelConfig = serde_json::from_value(serde_json::json!({ + "id": "model_1", + "name": "Provider", + "provider": "openai", + "model_name": "test-model", + "base_url": "https://example.com/v1", + "api_key": "key", + "enabled": true, + "enable_thinking_process": true + })) + .expect("legacy config should deserialize"); + + let value = serde_json::to_value(&config).expect("config should serialize"); + + assert!(value.get("enable_thinking_process").is_none()); + assert_eq!( + value.get("reasoning_mode").and_then(|v| v.as_str()), + Some("enabled") + ); + } +} diff --git a/src/crates/core/src/service/filesystem/listing.rs b/src/crates/core/src/service/filesystem/listing.rs index b29216049..fa4a7be74 100644 --- a/src/crates/core/src/service/filesystem/listing.rs +++ b/src/crates/core/src/service/filesystem/listing.rs @@ -26,7 +26,10 @@ struct TreeEntry { modified_time: SystemTime, } -pub fn list_directory_entries(dir_path: &str, limit: usize) -> BitFunResult> { +pub fn list_directory_entries( + dir_path: &str, + limit: usize, +) -> BitFunResult> { let path = Path::new(dir_path); if !path.exists() { return Err(BitFunError::service(format!( @@ -239,7 +242,10 @@ pub fn format_directory_listing(entries: &[DirectoryListingEntry], dir_path: &st "/".to_string() } } else if parts_for_parent.len() > 1 { - format!("{}/", parts_for_parent[..parts_for_parent.len() - 1].join("/")) + format!( + "{}/", + parts_for_parent[..parts_for_parent.len() - 1].join("/") + ) } else { "/".to_string() }; diff --git a/src/crates/core/src/service/mod.rs b/src/crates/core/src/service/mod.rs index 674b34dc1..63c7c05e2 100644 --- a/src/crates/core/src/service/mod.rs +++ 
b/src/crates/core/src/service/mod.rs @@ -2,10 +2,10 @@ //! //! Contains core business logic: Workspace, Config, FileSystem, Git, Agentic, AIRules, MCP. -pub mod announcement; // Announcement / feature-demo / tips system pub(crate) mod agent_memory; // Agent memory prompt helpers pub mod ai_memory; // AI memory point management pub mod ai_rules; // AI rules management +pub mod announcement; // Announcement / feature-demo / tips system pub(crate) mod bootstrap; // Workspace persona bootstrap helpers pub mod config; // Config management pub mod cron; // Scheduled jobs @@ -32,6 +32,7 @@ pub use terminal_core as terminal; // Re-export main components. pub use ai_memory::{AIMemory, AIMemoryManager, MemoryType}; pub use ai_rules::AIRulesService; +pub use announcement::{AnnouncementCard, AnnouncementScheduler, AnnouncementSchedulerRef}; pub use bootstrap::reset_workspace_persona_files_to_default; pub use config::{ConfigManager, ConfigProvider, ConfigService}; pub use cron::{ @@ -61,5 +62,4 @@ pub use token_usage::{ ModelTokenStats, SessionTokenStats, TimeRange, TokenUsageQuery, TokenUsageRecord, TokenUsageService, TokenUsageSummary, }; -pub use announcement::{AnnouncementCard, AnnouncementScheduler, AnnouncementSchedulerRef}; pub use workspace::{WorkspaceManager, WorkspaceProvider, WorkspaceService}; diff --git a/src/crates/core/src/service/remote_connect/bot/mod.rs b/src/crates/core/src/service/remote_connect/bot/mod.rs index 2ee11f238..0c0921c9a 100644 --- a/src/crates/core/src/service/remote_connect/bot/mod.rs +++ b/src/crates/core/src/service/remote_connect/bot/mod.rs @@ -646,9 +646,8 @@ mod tests { assert_eq!(paths.len(), 1); assert!(std::path::Path::new(&paths[0]).is_absolute()); - assert!(std::path::Path::new(&paths[0]).ends_with( - std::path::Path::new("artifacts").join("report.pptx") - )); + assert!(std::path::Path::new(&paths[0]) + .ends_with(std::path::Path::new("artifacts").join("report.pptx"))); assert!(std::path::Path::new(&paths[0]).exists()); let _ = std::fs::remove_dir_all(base); } diff --git a/src/crates/core/src/service/remote_connect/remote_server.rs b/src/crates/core/src/service/remote_connect/remote_server.rs index c626a03c1..51b003ef7 100644 --- a/src/crates/core/src/service/remote_connect/remote_server.rs +++ b/src/crates/core/src/service/remote_connect/remote_server.rs @@ -144,7 +144,11 @@ pub struct RemoteModelConfig { pub capabilities: Vec, pub enable_thinking_process: bool, #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_mode: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking_budget_tokens: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -173,22 +177,26 @@ async fn load_remote_model_catalog( .map_err(|e| format!("Failed to load global config: {e}"))?; let ai_config: AIConfig = global_config.ai; - let models: Vec = ai_config - .models - .into_iter() - .map(|model| RemoteModelConfig { - id: model.id, - name: model.name, - provider: model.provider, - base_url: model.base_url, - model_name: model.model_name, - context_window: model.context_window, - enabled: model.enabled, - capabilities: model - .capabilities - .into_iter() - .map(|capability| { - match capability { + let models: Vec = + ai_config + .models + .into_iter() + .map(|model| { + let reasoning_mode = model.effective_reasoning_mode(); + + RemoteModelConfig { + id: model.id, + name: model.name, + provider: model.provider, + base_url: model.base_url, + model_name: 
     #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_mode: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub reasoning_effort: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub thinking_budget_tokens: Option<u32>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -173,22 +177,26 @@ async fn load_remote_model_catalog(
         .map_err(|e| format!("Failed to load global config: {e}"))?;
 
     let ai_config: AIConfig = global_config.ai;
-    let models: Vec<RemoteModelConfig> = ai_config
-        .models
-        .into_iter()
-        .map(|model| RemoteModelConfig {
-            id: model.id,
-            name: model.name,
-            provider: model.provider,
-            base_url: model.base_url,
-            model_name: model.model_name,
-            context_window: model.context_window,
-            enabled: model.enabled,
-            capabilities: model
-                .capabilities
-                .into_iter()
-                .map(|capability| {
-                    match capability {
+    let models: Vec<RemoteModelConfig> =
+        ai_config
+            .models
+            .into_iter()
+            .map(|model| {
+                let reasoning_mode = model.effective_reasoning_mode();
+
+                RemoteModelConfig {
+                    id: model.id,
+                    name: model.name,
+                    provider: model.provider,
+                    base_url: model.base_url,
+                    model_name: model.model_name,
+                    context_window: model.context_window,
+                    enabled: model.enabled,
+                    capabilities: model
+                        .capabilities
+                        .into_iter()
+                        .map(|capability| {
+                            match capability {
                         crate::service::config::types::ModelCapability::TextChat => "text_chat",
                         crate::service::config::types::ModelCapability::ImageUnderstanding => {
                             "image_understanding"
@@ -209,12 +217,23 @@ async fn load_remote_model_catalog(
                             }
                         }
                         .to_string()
-                })
-                .collect(),
-            enable_thinking_process: model.enable_thinking_process,
-            reasoning_effort: model.reasoning_effort,
-        })
-        .collect();
+                        })
+                        .collect(),
+                    enable_thinking_process: model.enable_thinking_process,
+                    reasoning_mode: Some(
+                        match reasoning_mode {
+                            crate::service::config::types::ReasoningMode::Default => "default",
+                            crate::service::config::types::ReasoningMode::Enabled => "enabled",
+                            crate::service::config::types::ReasoningMode::Disabled => "disabled",
+                            crate::service::config::types::ReasoningMode::Adaptive => "adaptive",
+                        }
+                        .to_string(),
+                    ),
+                    reasoning_effort: model.reasoning_effort,
+                    thinking_budget_tokens: model.thinking_budget_tokens,
+                }
+            })
+            .collect();
 
     let session_model_id = if let Some(session_id) = session_id {
         resolve_session_model_id(session_id).await
diff --git a/src/crates/core/src/service/remote_ssh/workspace_state.rs b/src/crates/core/src/service/remote_ssh/workspace_state.rs
index c38385674..7e97d648b 100644
--- a/src/crates/core/src/service/remote_ssh/workspace_state.rs
+++ b/src/crates/core/src/service/remote_ssh/workspace_state.rs
@@ -422,7 +422,8 @@ impl RemoteWorkspaceStateManager {
         // Assistant sessions use client-local paths under ~/.bitfun/personal_assistant.
         // A registered remote root of `/` matches every absolute path; without an explicit
         // `remote_connection_id`, those paths must not be treated as SSH workspaces.
-        let is_local_assistant_path = get_path_manager_arc().is_local_assistant_workspace_path(path);
+        let is_local_assistant_path =
+            get_path_manager_arc().is_local_assistant_workspace_path(path);
         if is_local_assistant_path {
             let preferred_connection_id = preferred_connection_id?;
             let guard = self.registrations.read().await;
diff --git a/src/crates/core/src/util/types/config.rs b/src/crates/core/src/util/types/config.rs
index 076cfc7d2..09e7f6e68 100644
--- a/src/crates/core/src/util/types/config.rs
+++ b/src/crates/core/src/util/types/config.rs
@@ -1,4 +1,4 @@
-use crate::service::config::types::AIModelConfig;
+use crate::service::config::types::{AIModelConfig, ReasoningMode};
 use log::warn;
 use serde::{Deserialize, Serialize};
 
@@ -80,22 +80,81 @@ pub struct AIConfig {
     pub max_tokens: Option<u32>,
     pub temperature: Option<f32>,
     pub top_p: Option<f32>,
-    pub enable_thinking_process: bool,
-    pub support_preserved_thinking: bool,
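+    /// Effective reasoning mode, resolved from the model config via
+    /// `effective_reasoning_mode()` in the `TryFrom<AIModelConfig>` conversion below.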
+    pub reasoning_mode: ReasoningMode,
     pub inline_think_in_text: bool,
     pub custom_headers: Option<HashMap<String, String>>,
     /// "replace" (default) or "merge" (defaults first, then custom)
     pub custom_headers_mode: Option<String>,
     pub skip_ssl_verify: bool,
-    /// Reasoning effort for OpenAI Responses API ("low", "medium", "high", "xhigh")
+    /// Provider-specific reasoning effort.
     pub reasoning_effort: Option<String>,
+    /// Optional Anthropic manual thinking budget.
+    pub thinking_budget_tokens: Option<u32>,
     /// Custom JSON overriding default request body fields
     pub custom_request_body: Option<serde_json::Value>,
+    /// "merge" (default) or "trim" (keep only essential runtime fields, then apply custom JSON)
+    pub custom_request_body_mode: Option<String>,
+}
+
+impl TryFrom<AIModelConfig> for AIConfig {
+    type Error = String;
+    fn try_from(other: AIModelConfig) -> Result<Self, <Self as TryFrom<AIModelConfig>>::Error> {
+        let reasoning_mode = other.effective_reasoning_mode();
+
+        // Parse custom request body (convert JSON string to serde_json::Value)
+        let custom_request_body = if let Some(body_str) = &other.custom_request_body {
+            match serde_json::from_str::<serde_json::Value>(body_str) {
+                Ok(value) => Some(value),
+                Err(e) => {
+                    warn!(
+                        "Failed to parse custom_request_body: {}, config: {}",
+                        e, other.name
+                    );
+                    None
+                }
+            }
+        } else {
+            None
+        };
+
+        // Use stored request_url if present; otherwise derive from base_url + provider for legacy configs.
+        let request_url = other
+            .request_url
+            .clone()
+            .filter(|u| !u.is_empty())
+            .unwrap_or_else(|| {
+                resolve_request_url(&other.base_url, &other.provider, &other.model_name)
+            });
+
+        Ok(AIConfig {
+            name: other.name.clone(),
+            base_url: other.base_url.clone(),
+            request_url,
+            api_key: other.api_key.clone(),
+            model: other.model_name.clone(),
+            format: other.provider.clone(),
+            context_window: other.context_window.unwrap_or(128128),
+            max_tokens: other.max_tokens,
+            temperature: other.temperature,
+            top_p: other.top_p,
+            reasoning_mode,
+            inline_think_in_text: other.inline_think_in_text,
+            custom_headers: other.custom_headers,
+            custom_headers_mode: other.custom_headers_mode,
+            skip_ssl_verify: other.skip_ssl_verify,
+            reasoning_effort: other.reasoning_effort,
+            thinking_budget_tokens: other.thinking_budget_tokens,
+            custom_request_body,
+            custom_request_body_mode: other.custom_request_body_mode,
+        })
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use super::resolve_request_url;
+    use super::AIConfig;
+    use crate::service::config::types::{AIModelConfig, ModelCategory, ReasoningMode};
 
     #[test]
     fn resolves_openai_request_url() {
@@ -164,54 +223,50 @@ mod tests {
             "https://openrouter.ai/api/v1/chat/completions"
         );
     }
-}
 
-impl TryFrom<AIModelConfig> for AIConfig {
-    type Error = String;
-    fn try_from(other: AIModelConfig) -> Result<Self, <Self as TryFrom<AIModelConfig>>::Error> {
-        // Parse custom request body (convert JSON string to serde_json::Value)
-        let custom_request_body = if let Some(body_str) = &other.custom_request_body {
-            match serde_json::from_str::<serde_json::Value>(body_str) {
-                Ok(value) => Some(value),
-                Err(e) => {
-                    warn!(
-                        "Failed to parse custom_request_body: {}, config: {}",
-                        e, other.name
-                    );
-                    None
-                }
-            }
-        } else {
-            None
-        };
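+    // Shared fixture for the reasoning-mode conversion tests below.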
+    fn base_model_config() -> AIModelConfig {
+        AIModelConfig {
+            id: "model_1".to_string(),
+            name: "Provider".to_string(),
+            provider: "openai".to_string(),
+            model_name: "test-model".to_string(),
+            base_url: "https://example.com/v1".to_string(),
+            request_url: Some("https://example.com/v1/chat/completions".to_string()),
+            api_key: "key".to_string(),
+            context_window: Some(128000),
+            max_tokens: Some(4096),
+            temperature: None,
+            top_p: None,
+            enabled: true,
+            category: ModelCategory::GeneralChat,
+            capabilities: vec![],
+            recommended_for: vec![],
+            metadata: None,
+            enable_thinking_process: false,
+            reasoning_mode: None,
+            inline_think_in_text: false,
+            custom_headers: None,
+            custom_headers_mode: None,
+            skip_ssl_verify: false,
+            reasoning_effort: None,
+            thinking_budget_tokens: None,
+            custom_request_body: None,
+            custom_request_body_mode: None,
+        }
+    }
 
-        // Use stored request_url if present; otherwise derive from base_url + provider for legacy configs.
-        let request_url = other
-            .request_url
-            .filter(|u| !u.is_empty())
-            .unwrap_or_else(|| {
-                resolve_request_url(&other.base_url, &other.provider, &other.model_name)
-            });
+    #[test]
+    fn compatibility_false_thinking_maps_to_default_mode() {
+        let config = AIConfig::try_from(base_model_config()).expect("conversion should succeed");
+        assert_eq!(config.reasoning_mode, ReasoningMode::Default);
+    }
 
-        Ok(AIConfig {
-            name: other.name.clone(),
-            base_url: other.base_url.clone(),
-            request_url,
-            api_key: other.api_key.clone(),
-            model: other.model_name.clone(),
-            format: other.provider.clone(),
-            context_window: other.context_window.unwrap_or(128128),
-            max_tokens: other.max_tokens,
-            temperature: other.temperature,
-            top_p: other.top_p,
-            enable_thinking_process: other.enable_thinking_process,
-            support_preserved_thinking: other.support_preserved_thinking,
-            inline_think_in_text: other.inline_think_in_text,
-            custom_headers: other.custom_headers,
-            custom_headers_mode: other.custom_headers_mode,
-            skip_ssl_verify: other.skip_ssl_verify,
-            reasoning_effort: other.reasoning_effort,
-            custom_request_body,
-        })
+    #[test]
+    fn compatibility_true_thinking_maps_to_enabled_mode() {
+        let mut model = base_model_config();
+        model.enable_thinking_process = true;
+
+        let config = AIConfig::try_from(model).expect("conversion should succeed");
+        assert_eq!(config.reasoning_mode, ReasoningMode::Enabled);
     }
 }
diff --git a/src/mobile-web/src/pages/ChatPage.tsx b/src/mobile-web/src/pages/ChatPage.tsx
index e7830e38b..aab74c189 100644
--- a/src/mobile-web/src/pages/ChatPage.tsx
+++ b/src/mobile-web/src/pages/ChatPage.tsx
@@ -1719,6 +1719,14 @@ function getModelDisplayName(model: RemoteModelConfig | null): string {
   return model.model_name || model.name || '';
 }
 
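+// Prefer the server-provided reasoning_mode when present; fall back to the
+// legacy enable_thinking_process flag for older servers.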
+function isReasoningEnabled(model: RemoteModelConfig | null): boolean {
+  if (!model) return false;
+  if (model.reasoning_mode) {
+    return model.reasoning_mode === 'enabled' || model.reasoning_mode === 'adaptive';
+  }
+  return !!model.enable_thinking_process;
+}
+
 function getSelectedModelInfo(
   selectedModelId: string,
   catalog: RemoteModelCatalog | null,
@@ -1747,7 +1755,7 @@ function getSelectedModelInfo(
       ? (selectedModelId === 'primary' ? t('chat.modelPrimary') : t('chat.modelFast'))
       : t('chat.modelAuto'),
     meta: buildModelProviderMeta(resolved) || t('chat.modelAutoDesc'),
-    enableThinking: !!resolved?.enable_thinking_process,
+    enableThinking: isReasoningEnabled(resolved),
     reasoningEffort: resolved?.reasoning_effort,
   };
 }
@@ -1764,7 +1772,7 @@ function getSelectedModelInfo(
   return {
     label: getModelDisplayName(resolved),
     meta: buildModelProviderMeta(resolved),
-    enableThinking: resolved.enable_thinking_process,
+    enableThinking: isReasoningEnabled(resolved),
     reasoningEffort: resolved.reasoning_effort,
   };
 }
@@ -1914,7 +1922,7 @@ const ModelSelectorPill: React.FC<{
         {getModelDisplayName(model)}
-        {model.enable_thinking_process && (
+        {isReasoningEnabled(model) && (
         )}
diff --git a/src/mobile-web/src/services/RemoteSessionManager.ts b/src/mobile-web/src/services/RemoteSessionManager.ts
index a5eb4fc0e..47d8d928e 100644
--- a/src/mobile-web/src/services/RemoteSessionManager.ts
+++ b/src/mobile-web/src/services/RemoteSessionManager.ts
@@ -53,7 +53,8 @@ export interface RemoteModelConfig {
   context_window?: number;
   enabled: boolean;
   capabilities: string[];
-  enable_thinking_process: boolean;
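+  // Newer servers send reasoning_mode; the legacy boolean stays as an optional fallback.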
+  enable_thinking_process?: boolean;
+  reasoning_mode?: 'default' | 'enabled' | 'disabled' | 'adaptive';
   reasoning_effort?: string;
 }
diff --git a/src/web-ui/src/flow_chat/components/ModelSelector.tsx b/src/web-ui/src/flow_chat/components/ModelSelector.tsx
index 490942f2e..31cb8dc4f 100644
--- a/src/web-ui/src/flow_chat/components/ModelSelector.tsx
+++ b/src/web-ui/src/flow_chat/components/ModelSelector.tsx
@@ -13,6 +13,7 @@ import { useTranslation } from 'react-i18next';
 import { configManager } from '@/infrastructure/config/services/ConfigManager';
 import { agentAPI } from '@/infrastructure/api/service-api/AgentAPI';
 import { getProviderDisplayName } from '@/infrastructure/config/services/modelConfigs';
+import { getEffectiveReasoningMode, isReasoningVisiblyEnabled } from '@/infrastructure/config/utils/reasoning';
 import { globalEventBus } from '@/infrastructure/event-bus';
 import type { AIModelConfig } from '@/infrastructure/config/types';
 import { Tooltip } from '@/component-library';
@@ -223,7 +224,7 @@ export const ModelSelector: React.FC = ({
         providerName: getProviderDisplayName(model),
         provider: model.provider,
         contextWindow: model.context_window,
-        enableThinking: model.enable_thinking_process,
+        enableThinking: isReasoningVisiblyEnabled(getEffectiveReasoningMode(model)),
         reasoningEffort: model.reasoning_effort,
       };
     }
@@ -238,7 +239,7 @@ export const ModelSelector: React.FC = ({
       providerName: getProviderDisplayName(model),
       provider: model.provider,
       contextWindow: model.context_window,
-      enableThinking: model.enable_thinking_process,
+      enableThinking: isReasoningVisiblyEnabled(getEffectiveReasoningMode(model)),
       reasoningEffort: model.reasoning_effort,
     };
   }, [getCurrentModelId, allModels, defaultModels, t]);
@@ -258,7 +259,7 @@ export const ModelSelector: React.FC = ({
       providerName: getProviderDisplayName(m),
       provider: m.provider,
       contextWindow: m.context_window,
-      enableThinking: m.enable_thinking_process,
+      enableThinking: isReasoningVisiblyEnabled(getEffectiveReasoningMode(m)),
       reasoningEffort: m.reasoning_effort,
     }));
   }, [allModels]);
diff --git a/src/web-ui/src/infrastructure/config/components/AIModelConfig.scss b/src/web-ui/src/infrastructure/config/components/AIModelConfig.scss
index 8058deb48..78a95ce3e 100644
--- a/src/web-ui/src/infrastructure/config/components/AIModelConfig.scss
+++ b/src/web-ui/src/infrastructure/config/components/AIModelConfig.scss
@@ -258,6 +258,18 @@
     width: 100%;
   }
 
+  &__warning-inline {
+    display: inline-flex;
+    align-items: flex-start;
+    gap: $size-gap-2;
+    color: var(--color-warning);
+
+    svg {
+      flex-shrink: 0;
+      margin-top: 1px;
+    }
+  }
+
   &__json-status {
     font-size: var(--font-size-xs);
@@ -832,7 +844,12 @@
     grid-template-columns: 1fr !important;
   }
 
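+  // Keep the label and switch of a toggle row on one compact line.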
+  .bitfun-config-page-row.bitfun-ai-model-config__toggle-row {
+    grid-template-columns: minmax(0, 1fr) auto !important;
+    gap: $size-gap-3;
+  }
+
-  .bitfun-config-page-row:not(.bitfun-config-page-row--multiline),
+  .bitfun-config-page-row:not(.bitfun-config-page-row--multiline):not(.bitfun-ai-model-config__toggle-row),
   .bitfun-config-page-row--multiline.bitfun-config-page-row--wide {
     grid-template-columns: minmax(0, 2fr) minmax(0, 11fr) !important;
   }
@@ -849,6 +866,11 @@
     padding: $size-gap-3 $size-gap-4;
   }
 
+  .bitfun-config-page-row.bitfun-ai-model-config__toggle-row {
+    grid-template-columns: minmax(0, 1fr) auto !important;
+    gap: $size-gap-3;
+  }
+
   .bitfun-config-page-row--multiline.bitfun-config-page-row--wide {
     grid-template-columns: minmax(0, 2fr) minmax(0, 8fr) !important;
     gap: $size-gap-4;
@@ -880,6 +902,12 @@
     }
   }
 
+  .bitfun-ai-model-config__toggle-row .bitfun-config-page-row__control {
+    width: auto;
+    min-width: auto;
+    justify-self: end;
+  }
+
   .bitfun-config-page-row--multiline .bitfun-config-page-row__control {
     flex-direction: column;
     align-items: stretch;
@@ -1464,75 +1492,59 @@
 
-  &__radio-label {
+  &__inline-header {
     display: flex;
-    align-items: center;
-    cursor: pointer;
-    font-size: var(--font-size-sm);
-    color: var(--color-text-secondary);
-    padding: $size-gap-2 $size-gap-3;
-    border-radius: $size-radius-sm;
-    border: 1px solid transparent;
-    background: transparent;
-    transition: all $motion-base $easing-standard;
+    align-items: flex-start;
+    justify-content: space-between;
+    gap: $size-gap-3;
+    width: 100%;
+    flex-wrap: wrap;
+  }
 
-    input[type="radio"] {
-      width: 14px;
-      height: 14px;
-      flex-shrink: 0;
-      accent-color: var(--color-accent-500);
-      cursor: pointer;
-    }
+  &__inline-header-main {
+    display: inline-flex;
+    align-items: center;
+    gap: $size-gap-2;
+    min-width: 0;
+  }
 
-    span {
-      margin-left: $size-gap-2;
-    }
+  &__inline-header-info {
+    display: inline-flex;
+    align-items: center;
+    justify-content: center;
+    width: 18px;
+    height: 18px;
+    border-radius: 999px;
+    color: var(--color-text-muted);
+    cursor: help;
+    transition: color $motion-base $easing-standard, background-color $motion-base $easing-standard;
 
-    &:hover {
+    &:hover,
+    &:focus-visible {
       color: var(--color-text-primary);
       background: var(--element-bg-subtle);
-    }
-
-    &:has(input[type="radio"]:checked) {
-      color: var(--color-accent-500);
-      background: var(--color-accent-100);
-      border-color: var(--color-accent-500);
-      font-weight: $font-weight-medium;
+      outline: none;
     }
   }
 
-  &__header-mode {
-    padding: $size-gap-4;
-    background: var(--element-bg-subtle);
-    border-radius: $size-radius-sm;
-    margin-bottom: $size-gap-3;
-    border: 1px solid var(--border-base);
-
-    > label {
-      font-size: var(--font-size-xs) !important;
-      color: var(--color-text-secondary) !important;
-      margin-bottom: $size-gap-3;
-      display: block;
-      font-weight: $font-weight-medium;
-    }
+  &__inline-header-actions {
+    display: inline-flex;
+    align-items: center;
+    gap: $size-gap-2;
+    margin-left: auto;
+    flex-shrink: 0;
+  }
 
-    > div {
-      display: flex;
-      gap: $size-gap-3;
-      flex-wrap: wrap;
-    }
+  &__mode-button {
+    min-width: 56px;
+  }
 
-    > small {
-      display: block;
-      color: var(--color-text-muted);
-      font-size: 11px;
-      margin-top: $size-gap-3;
-      line-height: 1.4;
-      padding-top: $size-gap-2;
-      border-top: 1px dashed var(--border-base);
-    }
+  &__header-tooltip {
+    display: flex;
+    flex-direction: column;
+    gap: $size-gap-2;
+    max-width: 320px;
+    line-height: 1.5;
   }
 
   &__warning {
diff --git a/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx b/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx
index a471dca77..f22b1912e 100644
--- a/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx
+++ b/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx
@@ -1,15 +1,17 @@
 import React, { useState, useEffect, useMemo, useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
-import { Plus, SquarePen, Trash2, Wifi, Loader, AlertTriangle, X, Settings, ExternalLink, Eye, EyeOff, ChevronDown, ChevronRight } from 'lucide-react';
-import { Button, Switch, Select, IconButton, NumberInput, Card, Checkbox, Modal, Input, Textarea, type SelectOption } from '@/component-library';
+import { Plus, SquarePen, Trash2, Wifi, Loader, AlertTriangle, X, Settings, ExternalLink, Eye, EyeOff, ChevronDown, ChevronRight, Info } from 'lucide-react';
+import { Button, Switch, Select, IconButton, NumberInput, Card, Modal, Input, Textarea, Tooltip, type SelectOption } from '@/component-library';
 import {
   AIModelConfig as AIModelConfigType,
   ProxyConfig,
   ModelCategory,
-  ModelCapability
+  ModelCapability,
+  ReasoningMode
 } from '../types';
 import { configManager } from '../services/ConfigManager';
 import { PROVIDER_TEMPLATES, getModelDisplayName, getProviderDisplayName, getProviderTemplateId } from '../services/modelConfigs';
+import { DEFAULT_REASONING_MODE, getEffectiveReasoningMode, supportsAnthropicAdaptive, supportsAnthropicReasoning, supportsAnthropicThinkingBudget, supportsResponsesReasoning } from '../utils/reasoning';
 import { aiApi, systemAPI } from '@/infrastructure/api';
 import { useNotification } from '@/shared/notification-system';
 import { ConfigPageHeader, ConfigPageLayout, ConfigPageContent, ConfigPageSection, ConfigPageRow, ConfigCollectionItem } from './common';
@@ -32,7 +34,9 @@ interface SelectedModelDraft {
   category: ModelCategory;
   contextWindow: number;
   maxTokens: number;
-  enableThinking: boolean;
+  reasoningMode: ReasoningMode;
+  reasoningEffort?: string;
+  thinkingBudgetTokens?: number;
 }
 
 interface ProviderGroup {
@@ -42,7 +46,7 @@ function isResponsesProvider(provider?: string): boolean {
-  return provider === 'response' || provider === 'responses';
+  return supportsResponsesReasoning(provider);
 }
 
 function createModelDraft(
@@ -59,7 +63,39 @@ function createModelDraft(
     category: overrides?.category ?? baseConfig?.category ?? 'general_chat',
     contextWindow: overrides?.contextWindow ?? baseConfig?.context_window ?? 128000,
     maxTokens: overrides?.maxTokens ?? baseConfig?.max_tokens ?? 8192,
-    enableThinking: overrides?.enableThinking ?? baseConfig?.enable_thinking_process ?? false,
+    reasoningMode: overrides?.reasoningMode ?? getEffectiveReasoningMode(baseConfig),
+    reasoningEffort: overrides?.reasoningEffort ?? baseConfig?.reasoning_effort,
+    thinkingBudgetTokens: overrides?.thinkingBudgetTokens ?? baseConfig?.thinking_budget_tokens,
+  };
+}
+
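+// Reconcile a draft's reasoning fields with what the target provider supports:
+// Responses providers always use the default mode plus an effort level,
+// "adaptive" is limited to Anthropic models that support it, and a manual
+// thinking budget only applies to Anthropic models in "enabled" mode. For
+// example, switching an "adaptive" Anthropic draft to an OpenAI provider
+// downgrades the mode to "enabled" and clears the effort and budget fields.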
+function normalizeDraftReasoningForProvider(
+  draft: SelectedModelDraft,
+  provider?: string
+): SelectedModelDraft {
+  let reasoningMode = draft.reasoningMode;
+
+  if (supportsResponsesReasoning(provider)) {
+    reasoningMode = DEFAULT_REASONING_MODE;
+  } else if (!supportsAnthropicReasoning(provider) && reasoningMode === 'adaptive') {
+    reasoningMode = 'enabled';
+  } else if (supportsAnthropicReasoning(provider)
+    && reasoningMode === 'adaptive'
+    && !supportsAnthropicAdaptive(draft.modelName)) {
+    reasoningMode = 'enabled';
+  }
+
+  const keepReasoningEffort = supportsResponsesReasoning(provider)
+    || (supportsAnthropicReasoning(provider) && reasoningMode === 'adaptive');
+  const keepThinkingBudget = supportsAnthropicReasoning(provider)
+    && reasoningMode === 'enabled'
+    && supportsAnthropicThinkingBudget(draft.modelName);
+
+  return {
+    ...draft,
+    reasoningMode,
+    reasoningEffort: keepReasoningEffort ? draft.reasoningEffort : undefined,
+    thinkingBudgetTokens: keepThinkingBudget ? draft.thinkingBudgetTokens : undefined,
+  };
+}
@@ -257,8 +293,10 @@ const AIModelConfig: React.FC = () => {
     [requestFormatOptions]
   );
 
-  const reasoningEffortOptions = useMemo(
+  const responsesReasoningEffortOptions = useMemo(
     () => [
+      { label: 'None', value: 'none' },
+      { label: 'Minimal', value: 'minimal' },
       { label: 'Low', value: 'low' },
       { label: 'Medium', value: 'medium' },
       { label: 'High', value: 'high' },
@@ -267,14 +305,33 @@ const AIModelConfig: React.FC = () => {
     ],
     []
   );
 
-  const thinkingModeOptions = useMemo(
+  const anthropicReasoningEffortOptions = useMemo(
     () => [
-      { label: t('thinking.optionEnabled'), value: 'enabled' },
-      { label: t('thinking.optionDisabled'), value: 'disabled' },
+      { label: 'Low', value: 'low' },
+      { label: 'Medium', value: 'medium' },
+      { label: 'High', value: 'high' },
+      { label: 'Max', value: 'max' },
     ],
-    [t]
+    []
   );
 
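+  // "adaptive" is offered only for Anthropic providers, and stays listed for a
+  // model whose current mode is already "adaptive".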
+  const buildReasoningModeOptions = useCallback((provider?: string, modelName?: string, currentMode?: ReasoningMode): SelectOption[] => {
+    const options: SelectOption[] = [
+      { label: t('thinking.optionDefault'), value: DEFAULT_REASONING_MODE },
+      { label: t('thinking.optionEnabled'), value: 'enabled' },
+      { label: t('thinking.optionDisabled'), value: 'disabled' },
+    ];
+
+    if (
+      supportsAnthropicReasoning(provider)
+      && (supportsAnthropicAdaptive(modelName) || currentMode === 'adaptive')
+    ) {
+      options.push({ label: t('thinking.optionAdaptive'), value: 'adaptive' });
+    }
+
+    return options;
+  }, [t]);
+
   const categoryOptions = useMemo(
     () => [
       { label: t('category.general_chat'), value: 'general_chat' },
@@ -291,6 +348,26 @@ const AIModelConfig: React.FC = () => {
     [t]
   );
 
+  const getCustomRequestBodyTrimHint = useCallback((provider?: string): string => {
+    switch (provider) {
+      case 'responses':
+        return t('advancedSettings.customRequestBody.trimHintResponses');
+      case 'anthropic':
+        return t('advancedSettings.customRequestBody.trimHintAnthropic');
+      case 'gemini':
+        return t('advancedSettings.customRequestBody.trimHintGemini');
+      case 'openai':
+      default:
+        return t('advancedSettings.customRequestBody.trimHintOpenAI');
+    }
+  }, [t]);
+
+  const getCustomRequestBodyModeHint = useCallback((provider?: string, mode?: string | null): string => {
+    return mode === 'trim'
+      ? getCustomRequestBodyTrimHint(provider)
+      : t('advancedSettings.customRequestBody.modeMergeHint');
+  }, [getCustomRequestBodyTrimHint, t]);
+
   const loadConfig = useCallback(async () => {
     try {
@@ -362,7 +439,9 @@ const AIModelConfig: React.FC = () => {
         configId: config.id,
         contextWindow: config.context_window || 128000,
         maxTokens: config.max_tokens || 8192,
-        enableThinking: config.enable_thinking_process ?? false,
+        reasoningMode: getEffectiveReasoningMode(config),
+        reasoningEffort: config.reasoning_effort,
+        thinkingBudgetTokens: config.thinking_budget_tokens,
       }))
     );
 
@@ -545,7 +624,6 @@ const AIModelConfig: React.FC = () => {
       base_url: resolvedBaseUrl,
       request_url: config.request_url || resolveRequestUrl(resolvedBaseUrl, resolvedProvider, resolvedModelName),
       model_name: resolvedModelName,
-      description: config.description,
       context_window: config.context_window || 128000,
       max_tokens: config.max_tokens || 8192,
       temperature: config.temperature,
@@ -555,14 +633,15 @@ const AIModelConfig: React.FC = () => {
       capabilities: config.capabilities || ['text_chat'],
       recommended_for: config.recommended_for || [],
       metadata: config.metadata || {},
-      enable_thinking_process: config.enable_thinking_process ?? false,
-      support_preserved_thinking: config.support_preserved_thinking ?? false,
+      reasoning_mode: config.reasoning_mode ?? getEffectiveReasoningMode(config),
      inline_think_in_text: config.inline_think_in_text ?? false,
       reasoning_effort: config.reasoning_effort,
+      thinking_budget_tokens: config.thinking_budget_tokens,
       custom_headers: config.custom_headers,
       custom_headers_mode: config.custom_headers_mode,
       skip_ssl_verify: config.skip_ssl_verify ?? false,
-      custom_request_body: config.custom_request_body
+      custom_request_body: config.custom_request_body,
+      custom_request_body_mode: config.custom_request_body_mode,
     };
   };
 
@@ -576,6 +655,7 @@ const AIModelConfig: React.FC = () => {
     custom_headers_mode: config.custom_headers_mode || null,
     custom_headers: config.custom_headers || null,
     custom_request_body: config.custom_request_body || null,
+    custom_request_body_mode: config.custom_request_body_mode || null,
   });
 
   const fetchRemoteModels = async (config: Partial<AIModelConfigType> | null) => {
@@ -688,7 +768,7 @@ const AIModelConfig: React.FC = () => {
       : (defaultModel ? [createModelDraft(defaultModel, {
           context_window: 128000,
           max_tokens: 8192,
-          enable_thinking_process: false,
+          reasoning_mode: DEFAULT_REASONING_MODE,
         })] : [])
     );
     setShowAdvancedSettings(false);
@@ -742,26 +822,24 @@ const AIModelConfig: React.FC = () => {
       model_name: '',
       provider: config.provider,
       enabled: true,
-      description: config.description,
       context_window: config.context_window || 128000,
       max_tokens: config.max_tokens || 8192,
       category: config.category || 'general_chat',
       capabilities: config.capabilities || getCapabilitiesByCategory(config.category || 'general_chat'),
       recommended_for: config.recommended_for || [],
       metadata: config.metadata || {},
-      enable_thinking_process: config.enable_thinking_process ?? false,
-      support_preserved_thinking: config.support_preserved_thinking ?? false,
       inline_think_in_text: config.inline_think_in_text ?? false,
-      reasoning_effort: config.reasoning_effort,
       custom_headers: config.custom_headers,
       custom_headers_mode: config.custom_headers_mode,
       skip_ssl_verify: config.skip_ssl_verify ?? false,
       custom_request_body: config.custom_request_body,
+      custom_request_body_mode: config.custom_request_body_mode,
     });
     setSelectedModelDrafts(createDraftsFromConfigs(configuredProviderModels));
     setShowAdvancedSettings(
       !!config.inline_think_in_text ||
       !!config.skip_ssl_verify ||
+      config.custom_request_body_mode === 'trim' ||
       (!!config.custom_request_body && config.custom_request_body.trim() !== '') ||
       (!!config.custom_headers && Object.keys(config.custom_headers).length > 0)
     );
@@ -778,7 +856,9 @@ const AIModelConfig: React.FC = () => {
       createModelDraft(config.model_name, config, {
         contextWindow: config.context_window || 128000,
         maxTokens: config.max_tokens || 8192,
-        enableThinking: config.enable_thinking_process ?? false,
+        reasoningMode: getEffectiveReasoningMode(config),
+        reasoningEffort: config.reasoning_effort,
+        thinkingBudgetTokens: config.thinking_budget_tokens,
       })
     ]);
 
@@ -787,6 +867,7 @@ const AIModelConfig: React.FC = () => {
     setShowAdvancedSettings(
       hasCustomHeaders ||
       hasCustomBody ||
+      config.custom_request_body_mode === 'trim' ||
       !!config.skip_ssl_verify ||
       !!config.inline_think_in_text
     );
@@ -833,21 +914,21 @@ const AIModelConfig: React.FC = () => {
         model_name: draft.modelName,
         provider: editingConfig.provider || 'openai',
         enabled: editingConfig.enabled ?? true,
-        description: editingConfig.description,
         context_window: draft.contextWindow,
         max_tokens: draft.maxTokens,
         category: draft.category,
         capabilities: getCapabilitiesByCategory(draft.category),
         recommended_for: editingConfig.recommended_for || [],
         metadata: editingConfig.metadata,
-        enable_thinking_process: draft.enableThinking,
-        support_preserved_thinking: editingConfig.support_preserved_thinking ?? false,
+        reasoning_mode: draft.reasoningMode,
         inline_think_in_text: editingConfig.inline_think_in_text ?? false,
-        reasoning_effort: editingConfig.reasoning_effort,
+        reasoning_effort: draft.reasoningEffort,
+        thinking_budget_tokens: draft.thinkingBudgetTokens,
         custom_headers: editingConfig.custom_headers,
         custom_headers_mode: editingConfig.custom_headers_mode,
         skip_ssl_verify: editingConfig.skip_ssl_verify ?? false,
-        custom_request_body: editingConfig.custom_request_body
+        custom_request_body: editingConfig.custom_request_body,
+        custom_request_body_mode: editingConfig.custom_request_body_mode,
       };
     });
@@ -1264,6 +1345,43 @@ const AIModelConfig: React.FC = () => {
     );
 
+  const formatReasoningSummary = (draft: SelectedModelDraft) => {
+    const parts: string[] = [];
+
+    switch (draft.reasoningMode) {
+      case 'enabled':
+        parts.push(t('thinking.summaryEnabled'));
+        break;
+      case 'disabled':
+        parts.push(t('thinking.summaryDisabled'));
+        break;
+      case 'adaptive':
+        parts.push(t('thinking.summaryAdaptive'));
+        break;
+      default:
+        parts.push(t('thinking.summaryDefault'));
+        break;
+    }
+
+    if (draft.reasoningEffort) {
+      parts.push(draft.reasoningEffort);
+    }
+
+    return parts.join(' · ');
+  };
+
+  const getDraftReasoningEffortOptions = (provider?: string) => {
+    if (supportsResponsesReasoning(provider)) {
+      return responsesReasoningEffortOptions;
+    }
+
+    if (supportsAnthropicReasoning(provider)) {
+      return anthropicReasoningEffortOptions;
+    }
+
+    return [];
+  };
+
   const renderSelectedModelRows = () => {
     if (selectedModelDrafts.length === 0) {
       return (
@@ -1280,6 +1398,19 @@ const AIModelConfig: React.FC = () => {
       const categoryLabel = categoryCompactLabels[draft.category] ?? draft.category;
       const canToggleExpand = selectedModelDrafts.length > 1;
       const modelDisplayName = draft.modelName;
+      const reasoningModeOptions = buildReasoningModeOptions(editingConfig.provider, draft.modelName, draft.reasoningMode);
+      const reasoningEffortOptions = getDraftReasoningEffortOptions(editingConfig.provider);
+      const showReasoningModeControl = !supportsResponsesReasoning(editingConfig.provider);
+      const showReasoningEffortControl = reasoningEffortOptions.length > 0
+        && (
+          supportsResponsesReasoning(editingConfig.provider)
+          || (supportsAnthropicReasoning(editingConfig.provider) && draft.reasoningMode === 'adaptive')
+        );
+      const showThinkingBudgetControl = supportsAnthropicReasoning(editingConfig.provider)
+        && draft.reasoningMode === 'enabled'
+        && supportsAnthropicThinkingBudget(draft.modelName);
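+      // Display fallback for the budget field: 75% of the max output tokens, capped at 10,000.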
+      const displayedThinkingBudget = draft.thinkingBudgetTokens
+        ?? Math.min(Math.floor(draft.maxTokens * 0.75), 10000);
 
       return (
@@ -1335,7 +1466,7 @@ const AIModelConfig: React.FC = () => {
                         {' · '}
                         {formatTokenCountShort(draft.maxTokens)} out
                         {' · '}
-                        {draft.enableThinking ? t('thinking.summaryOn') : t('thinking.summaryOff')}
+                        {formatReasoningSummary(draft)}
                       )}
@@ -1389,15 +1520,43 @@ const AIModelConfig: React.FC = () => {
                     disableWheel
                   />
-
-                  {t('thinking.enable')}
-
+                {showReasoningModeControl && (
+                    <Select
+                      value={draft.reasoningMode}
+                      onChange={(value) => updateModelDraft(draft.modelName, { reasoningMode: value as ReasoningMode })}
+                      options={reasoningModeOptions}
+                      size="small"
+                    />
+                )}
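+                {/* Effort selector: always shown for Responses providers; for Anthropic only in adaptive mode. */}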
+                {showReasoningEffortControl && (
+                    {t('reasoningEffort.label')}
+                    <Select
+                      value={draft.reasoningEffort}
+                      onChange={(value) => updateModelDraft(draft.modelName, { reasoningEffort: value as string })}
+                      options={reasoningEffortOptions}
+                      size="small"
+                    />
+                )}
 
                   onChange={(value) => {
+                    const provider = value as string;
                     resetRemoteModelDiscovery();
+                    setSelectedModelDrafts(prevDrafts =>
+                      prevDrafts.map(draft => normalizeDraftReasoningForProvider(draft, provider))
+                    );
                     setEditingConfig(prev => ({
                       ...prev,
-                      provider: value as string,
-                      request_url: resolveRequestUrl(prev?.base_url || '', value as string, prev?.model_name || '')
+                      provider,
+                      request_url: resolveRequestUrl(prev?.base_url || '', provider, prev?.model_name || '')
                     }));
                   }}
                   placeholder={t('form.providerPlaceholder')}
@@ -1545,11 +1708,6 @@ const AIModelConfig: React.FC = () => {
           {renderSelectedModelRows()}
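+          {/* Reasoning effort is now configured per model inside renderSelectedModelRows. */}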
 
-          {isResponsesProvider(editingConfig.provider) && (
-
-          )}
 
                   onChange={(value) => {
                     const provider = value as string;
                     resetRemoteModelDiscovery();
+                    setSelectedModelDrafts(prevDrafts =>
+                      prevDrafts.map(draft => normalizeDraftReasoningForProvider(draft, provider))
+                    );
                     setEditingConfig(prev => ({
                       ...prev,
                       provider,
                       request_url: resolveRequestUrl(prev?.base_url || '', provider, prev?.model_name || ''),
                       inline_think_in_text: provider === 'openai' ? (prev?.inline_think_in_text ?? false) : false,
-                      reasoning_effort: isResponsesProvider(provider) ? (prev?.reasoning_effort || 'medium') : undefined,
                     }));
                   }}
                   placeholder={t('form.providerPlaceholder')}
                   options={requestFormatOptions}
                   size="small"
                 />
@@ -1668,14 +1828,6 @@ const AIModelConfig: React.FC = () => {
           {renderSelectedModelRows()}
 
-          {isResponsesProvider(editingConfig.provider) && (
-