diff --git a/src/reloaded-code-serdesai/Cargo.toml b/src/reloaded-code-serdesai/Cargo.toml index cb5cc9c..09ac1bd 100644 --- a/src/reloaded-code-serdesai/Cargo.toml +++ b/src/reloaded-code-serdesai/Cargo.toml @@ -48,10 +48,11 @@ mistral = ["serdes-ai-models/mistral"] ollama = ["serdes-ai-models/ollama"] openrouter = ["serdes-ai-models/openrouter"] # Sandbox feature - enables bubblewrap sandboxing -linux-bubblewrap = [ - "dep:reloaded-code-bubblewrap", - "reloaded-code-core/linux-bubblewrap", -] +linux-bubblewrap = ["dep:reloaded-code-bubblewrap", "reloaded-code-core/linux-bubblewrap"] + +# Mock feature - enables mock models types and model_override injection +# Use for testing functionality with mocks. +mock = [] [dependencies] # Core tool operations (file read/write/edit, glob, grep, bash, etc.) diff --git a/src/reloaded-code-serdesai/src/agent_runtime/task.rs b/src/reloaded-code-serdesai/src/agent_runtime/task.rs index 1245767..8e7f1f7 100644 --- a/src/reloaded-code-serdesai/src/agent_runtime/task.rs +++ b/src/reloaded-code-serdesai/src/agent_runtime/task.rs @@ -8,6 +8,8 @@ use crate::task::TaskHandle; use reloaded_code_agents::AgentRuntime; use reloaded_code_core::{CredentialLookup, CredentialResolver, models::ModelCatalog}; use serdes_ai::{Agent, AgentBuilder}; +#[cfg(any(test, feature = "mock"))] +use serdes_ai_models::BoxedModel; use std::path::Path; use std::sync::Arc; @@ -58,6 +60,8 @@ where model_catalog, credentials, workspace_root, + #[cfg(any(test, feature = "mock"))] + model_override: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] bash_sandbox: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] @@ -198,6 +202,26 @@ where pub fn credentials(&self) -> &C { self.context.credentials.as_ref() } + + /// Sets a mock model that overrides the resolved catalog model. + /// + /// # Arguments + /// - `model`: Any [`serdes_ai_models::Model`] implementation to use instead + /// of the catalog-resolved model. 
+ /// + /// # Returns + /// `Self` for chaining. + /// + /// # Panics + /// Panics if the [`AgentBuildContext`] has already been cloned (i.e., the + /// inner `Arc` is not unique). This must be called before sharing the context. + #[cfg(any(test, feature = "mock"))] + pub fn with_model_override(mut self, model: impl serdes_ai_models::Model + 'static) -> Self { + Arc::get_mut(&mut self.context) + .expect("with_model_override must be called before sharing the context") + .model_override = Some(Arc::new(model)); + self + } } /// Shared owned state for builds that may happen later during Task delegation. @@ -208,6 +232,8 @@ pub(crate) struct TaskBuildContext, credentials: Arc, workspace_root: Arc, + #[cfg(any(test, feature = "mock"))] + model_override: Option, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] bash_sandbox: Option>, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] @@ -251,6 +277,8 @@ where model_catalog, credentials, workspace_root, + #[cfg(any(test, feature = "mock"))] + model_override: None, bash_sandbox: Some(bash_sandbox), _sandbox_tmpdir, } @@ -274,6 +302,8 @@ where model_catalog, credentials, workspace_root, + #[cfg(any(test, feature = "mock"))] + model_override: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] bash_sandbox: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] @@ -329,8 +359,15 @@ where context.credentials.as_ref(), with_summaries, )?; - // Create an AgentBuilder pre-loaded with the resolved model. - let builder = AgentBuilder::<(), String>::from_arc(prepared.model().clone()); + // Create an AgentBuilder with the model (override wins over catalog-resolved). 
+ #[cfg(any(test, feature = "mock"))] + let model = context + .model_override + .clone() + .unwrap_or_else(|| prepared.model().clone()); + #[cfg(not(any(test, feature = "mock")))] + let model = prepared.model().clone(); + let builder = AgentBuilder::<(), String>::from_arc(model); // Create a TaskHandle for delegation if Task tool is attached later. let task_handle = TaskHandle::new(context.clone(), current_depth); // Select the sandbox profile (None on non-Linux or without the feature). @@ -464,6 +501,8 @@ mod tests { model_catalog, credentials, workspace_root: workspace_root(), + #[cfg(any(test, feature = "mock"))] + model_override: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] bash_sandbox: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] @@ -504,6 +543,8 @@ mod tests { model_catalog, credentials, workspace_root: workspace_root(), + #[cfg(any(test, feature = "mock"))] + model_override: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] bash_sandbox: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] @@ -543,6 +584,8 @@ mod tests { model_catalog, credentials, workspace_root: workspace_root(), + #[cfg(any(test, feature = "mock"))] + model_override: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] bash_sandbox: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] @@ -578,6 +621,8 @@ mod tests { model_catalog, credentials, workspace_root: workspace_root(), + #[cfg(any(test, feature = "mock"))] + model_override: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] bash_sandbox: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] @@ -650,6 +695,8 @@ mod tests { model_catalog, credentials, workspace_root: workspace_root(), + #[cfg(any(test, feature = "mock"))] + model_override: None, #[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))] bash_sandbox: None, #[cfg(all(feature = "linux-bubblewrap", target_os = 
"linux"))] diff --git a/src/reloaded-code-serdesai/src/lib.rs b/src/reloaded-code-serdesai/src/lib.rs index a93096f..73e995c 100644 --- a/src/reloaded-code-serdesai/src/lib.rs +++ b/src/reloaded-code-serdesai/src/lib.rs @@ -47,3 +47,6 @@ pub use reloaded_code_agents::{ AgentDefaults, AgentRuntime, AgentRuntimeBuilder, ModelResolutionError, ResolvedModel, resolve_model_with_catalog, }; + +#[cfg(any(test, feature = "mock"))] +pub mod mock; diff --git a/src/reloaded-code-serdesai/src/mock.rs b/src/reloaded-code-serdesai/src/mock.rs new file mode 100644 index 0000000..e696de1 --- /dev/null +++ b/src/reloaded-code-serdesai/src/mock.rs @@ -0,0 +1,235 @@ +//! Mock model types with streaming support for running agents without a real LLM provider. +//! +//! Wraps upstream [`serdes_ai_models`] mock types so they work with +//! [`Agent::run_stream`][`serdes_ai::Agent::run_stream`]. +//! +//! # Quick start +//! +//! ```text +//! use reloaded_code_serdesai::mock::{Streamed, tool_then_text}; +//! use serde_json::json; +//! +//! let model = tool_then_text("glob", json!({"pattern": "*.rs"}), "Done."); +//! let stream = agent.run_stream("prompt", ()).await?; // OK +//! ``` +//! +//! When using [`crate::AgentBuildContext`], call +//! [`with_model_override`](crate::AgentBuildContext::with_model_override) +//! to inject the mock model before calling [`build()`](crate::AgentBuildContext::build). + +// Re-export upstream mock types so users can still access the raw variants when needed. +pub use serdes_ai_models::{FunctionModel, MockModel, TestModel}; + +use async_trait::async_trait; +use futures::stream; +use serdes_ai::core::{ + FinishReason, ModelRequest, ModelResponse, ModelResponsePart, ModelResponseStreamEvent, +}; +use serdes_ai_models::Model as ModelTrait; +// Re-export the types from where serdes-ai-models exposes them. 
+use serdes_ai::core::ModelSettings;
+use serdes_ai_models::{
+    ModelCapability, ModelError, ModelProfile, ModelRequestParameters, StreamedResponse,
+};
+
+// ============================================================================
+// Streamed - wrapper that adds streaming support to any Model
+// ============================================================================
+
+/// Wrapper adding [`request_stream`](ModelTrait::request_stream) support to any [`ModelTrait`] implementation.
+///
+/// Delegates [`request`](ModelTrait::request) directly to the inner model and converts the non-streaming
+/// response into a sequence of [`ModelResponseStreamEvent`]s for streaming callers.
+///
+/// The stream is a replay: the inner model's `request` is awaited to completion
+/// first, then every response part is emitted as a start/end event pair. Only
+/// the response parts are replayed; other response fields (such as the finish
+/// reason) are not represented in the emitted events.
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use reloaded_code_serdesai::mock::{FunctionModel, Streamed};
+/// use serde_json::json;
+///
+/// let model = Streamed::new(FunctionModel::tool_call("glob", json!({"pattern": "*.rs"})));
+/// ```
+#[derive(Clone, Debug)]
+pub struct Streamed<T> {
+    /// The wrapped model that actually produces responses.
+    inner: T,
+    /// Name reported by [`ModelTrait::name`]; starts as the inner model's name.
+    name: String,
+}
+
+impl<T> Streamed<T> {
+    /// Wrap a model to add streaming support.
+    ///
+    /// The `name` defaults to the inner model's [`name()`](ModelTrait::name).
+    ///
+    /// # Arguments
+    /// - `inner`: The model whose non-streaming `request` backs both request paths.
+    ///
+    /// # Returns
+    /// The wrapper, reporting the inner model's name until overridden.
+    pub fn new(inner: T) -> Self
+    where
+        T: ModelTrait,
+    {
+        let name = inner.name().to_string();
+        Self { inner, name }
+    }
+
+    /// Set a custom name for the wrapped model.
+    ///
+    /// # Arguments
+    /// - `name`: Replacement for the name copied from the inner model.
+    ///
+    /// # Returns
+    /// `Self` for chaining.
+    pub fn with_name(mut self, name: impl Into<String>) -> Self {
+        self.name = name.into();
+        self
+    }
+}
+
+// NOTE(review): `T: ModelTrait` is the minimum bound this impl needs. Confirm
+// whether the trait (or `async_trait`) additionally requires `Send + Sync` or
+// `'static` on `T` here.
+#[async_trait]
+impl<T: ModelTrait> ModelTrait for Streamed<T> {
+    /// Returns the wrapper's own name, not necessarily the inner model's.
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// Forwards to the inner model.
+    fn system(&self) -> &str {
+        self.inner.system()
+    }
+
+    /// Forwards to the inner model.
+    fn profile(&self) -> &ModelProfile {
+        self.inner.profile()
+    }
+
+    /// Forwards the request unchanged to the inner model.
+    async fn request(
+        &self,
+        messages: &[ModelRequest],
+        settings: &ModelSettings,
+        params: &ModelRequestParameters,
+    ) -> Result<ModelResponse, ModelError> {
+        self.inner.request(messages, settings, params).await
+    }
+
+    /// Runs the inner model's non-streaming `request` and replays the finished
+    /// response as a stream of events.
+    ///
+    /// Errors from the inner request are returned before any stream is created;
+    /// every item of the returned stream is `Ok`.
+    async fn request_stream(
+        &self,
+        messages: &[ModelRequest],
+        settings: &ModelSettings,
+        params: &ModelRequestParameters,
+    ) -> Result<StreamedResponse, ModelError> {
+        let response = self.inner.request(messages, settings, params).await?;
+        let events = response_to_stream_events(response);
+        Ok(Box::pin(stream::iter(events.into_iter().map(Ok))))
+    }
+
+    /// Forwards to the inner model.
+    fn supports(&self, capability: ModelCapability) -> bool {
+        self.inner.supports(capability)
+    }
+}
+
+// ============================================================================
+// Convenience helpers
+// ============================================================================
+
+/// Build a mock model that calls `tool_name` with `args` on the **first** turn,
+/// then returns text that incorporates the real tool return on the **second** turn.
+///
+/// This prevents infinite loops when running agent examples that stream,
+/// because after the tool result is fed back the model answers with text.
+///
+/// The second-turn response includes whatever the real tool returned, so
+/// the output reflects actual tool execution rather than a canned message.
+///
+/// Turn detection is stateless: the message history is inspected on every
+/// call, and the text answer is produced as soon as any message contains a
+/// tool-return part.
+///
+/// # Arguments
+/// - `tool_name`: Name of the tool requested on the first turn.
+/// - `args`: JSON arguments sent with that tool call.
+/// - `fallback_text`: Leading text of the final answer. It is returned on its
+///   own when no text can be extracted from the tool returns; otherwise the
+///   extracted tool output is appended after a blank line.
+///
+/// # Returns
+/// A [`Streamed`] wrapper around the scripted [`FunctionModel`].
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use reloaded_code_serdesai::mock::tool_then_text;
+/// use serde_json::json;
+///
+/// let model = tool_then_text("glob", json!({"pattern": "*.rs"}), "Done.");
+/// ```
+pub fn tool_then_text(
+    tool_name: impl Into<String>,
+    args: serde_json::Value,
+    fallback_text: impl Into<String>,
+) -> Streamed<FunctionModel> {
+    let tool_name = tool_name.into();
+    let fallback_text = fallback_text.into();
+
+    let model = FunctionModel::new(move |messages, _settings| {
+        // Check whether the conversation already contains a tool return from a
+        // previous turn. If it does, we are on the second call and should
+        // produce a text response incorporating the real result.
+        let has_tool_return = messages.iter().any(|m| {
+            m.parts
+                .iter()
+                .any(|p| matches!(p, serdes_ai::core::ModelRequestPart::ToolReturn(_)))
+        });
+
+        if has_tool_return {
+            // Collect tool return content from the message history.
+            let tool_results: String = messages
+                .iter()
+                .flat_map(|m| m.tool_returns())
+                .map(extract_tool_return_text)
+                .collect::<Vec<_>>()
+                .join("\n");
+
+            // A tool return may carry no extractable text; answer with the
+            // caller's text alone in that case instead of a dangling separator.
+            let text = if tool_results.is_empty() {
+                fallback_text.clone()
+            } else {
+                format!("{fallback_text}\n\n{tool_results}")
+            };
+
+            ModelResponse::text(text)
+        } else {
+            // First call: emit a tool call so the agent executes the real tool.
+            ModelResponse::with_parts(vec![
+                ModelResponsePart::text(format!("Calling {tool_name}...")),
+                ModelResponsePart::tool_call(tool_name.clone(), args.clone()),
+            ])
+            .with_finish_reason(FinishReason::ToolCall)
+        }
+    });
+
+    Streamed::new(model)
+}
+
+// ============================================================================
+// Private helpers
+// ============================================================================
+
+/// Convert a complete response into the events a streaming caller would see.
+///
+/// Each response part becomes a `part_start` event carrying the whole part,
+/// immediately followed by the matching `part_end`; no delta events are emitted.
+/// Indices follow the order of `response.parts`.
+fn response_to_stream_events(response: ModelResponse) -> Vec<ModelResponseStreamEvent> {
+    // Exactly two events are produced per part.
+    let mut events = Vec::with_capacity(response.parts.len() * 2);
+
+    for (index, part) in response.parts.into_iter().enumerate() {
+        events.push(ModelResponseStreamEvent::part_start(index, part));
+        events.push(ModelResponseStreamEvent::part_end(index));
+    }
+
+    events
+}
+
+/// Extract human-readable text from a [`ToolReturnPart`].
+///
+/// Uses serde JSON round-tripping to avoid depending on the
+/// non-public `ToolReturnContent` enum variants directly.
+///
+/// # Arguments
+/// - `tr`: The tool return whose `content` field is rendered.
+///
+/// # Returns
+/// The first match of: the string `content` as-is, a non-string `content`
+/// pretty-printed as JSON, the `message` field prefixed with `[error] `, or the
+/// whole serialized value pretty-printed. If serialization fails, the `Debug`
+/// rendering of the content is returned.
+fn extract_tool_return_text(tr: &serdes_ai::core::ToolReturnPart) -> String {
+    // Serialize the content field to JSON, then extract readable text.
+    // ToolReturnContent variants produce:
+    //   Text    -> {"type":"text","content":"..."}
+    //   Json    -> {"type":"json","content":{...}}
+    //   Error   -> {"type":"error","message":"..."}
+    //   Multiple-> {"type":"multiple","items":[...]}
+    //   Image   -> {"type":"image","image":{...}}
+    let Ok(val) = serde_json::to_value(&tr.content) else {
+        return format!("{:?}", tr.content);
+    };
+
+    // Try text content field first (most common case).
+    if let Some(text) = val.get("content").and_then(|v| v.as_str()) {
+        return text.to_string();
+    }
+
+    // JSON content field.
+    if let Some(json_val) = val.get("content")
+        && let Ok(pretty) = serde_json::to_string_pretty(json_val)
+    {
+        return pretty;
+    }
+
+    // Error message field.
+    if let Some(msg) = val.get("message").and_then(|v| v.as_str()) {
+        return format!("[error] {msg}");
+    }
+
+    // Fallback: pretty-print the whole thing.
+    serde_json::to_string_pretty(&val).unwrap_or_else(|_| format!("{:?}", tr.content))
+}