Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/reloaded-code-serdesai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ mistral = ["serdes-ai-models/mistral"]
ollama = ["serdes-ai-models/ollama"]
openrouter = ["serdes-ai-models/openrouter"]
# Sandbox feature - enables bubblewrap sandboxing
linux-bubblewrap = [
"dep:reloaded-code-bubblewrap",
"reloaded-code-core/linux-bubblewrap",
]
linux-bubblewrap = ["dep:reloaded-code-bubblewrap", "reloaded-code-core/linux-bubblewrap"]

# Mock feature - enables mock models types and model_override injection
# Use for testing functionality with mocks.
mock = []

[dependencies]
# Core tool operations (file read/write/edit, glob, grep, bash, etc.)
Expand Down
51 changes: 49 additions & 2 deletions src/reloaded-code-serdesai/src/agent_runtime/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::task::TaskHandle;
use reloaded_code_agents::AgentRuntime;
use reloaded_code_core::{CredentialLookup, CredentialResolver, models::ModelCatalog};
use serdes_ai::{Agent, AgentBuilder};
#[cfg(any(test, feature = "mock"))]
use serdes_ai_models::BoxedModel;
use std::path::Path;
use std::sync::Arc;

Expand Down Expand Up @@ -58,6 +60,8 @@ where
model_catalog,
credentials,
workspace_root,
#[cfg(any(test, feature = "mock"))]
model_override: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
bash_sandbox: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
Expand Down Expand Up @@ -198,6 +202,26 @@ where
pub fn credentials(&self) -> &C {
self.context.credentials.as_ref()
}

/// Replaces the catalog-resolved model with a caller-supplied one.
///
/// Only compiled for tests or when the `mock` feature is enabled.
///
/// # Arguments
/// - `model`: Any [`serdes_ai_models::Model`] implementation. It is wrapped in an
///   [`Arc`] and stored as the context's model override.
///
/// # Returns
/// `Self`, so the call can be chained.
///
/// # Panics
/// Panics if the inner `Arc` holding the build context is not uniquely owned,
/// i.e. the context has already been shared. Call this before sharing the context.
#[cfg(any(test, feature = "mock"))]
pub fn with_model_override(mut self, model: impl serdes_ai_models::Model + 'static) -> Self {
    // Wrap the model first; the unique-ownership check happens afterwards.
    let shared_model = Arc::new(model);
    let context = Arc::get_mut(&mut self.context)
        .expect("with_model_override must be called before sharing the context");
    context.model_override = Some(shared_model);
    self
}
}

/// Shared owned state for builds that may happen later during Task delegation.
Expand All @@ -208,6 +232,8 @@ pub(crate) struct TaskBuildContext<C: CredentialLookup + Send + Sync + ?Sized =
model_catalog: Arc<ModelCatalog>,
credentials: Arc<C>,
workspace_root: Arc<Path>,
#[cfg(any(test, feature = "mock"))]
model_override: Option<BoxedModel>,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
bash_sandbox: Option<Arc<Profile>>,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
Expand Down Expand Up @@ -251,6 +277,8 @@ where
model_catalog,
credentials,
workspace_root,
#[cfg(any(test, feature = "mock"))]
model_override: None,
bash_sandbox: Some(bash_sandbox),
_sandbox_tmpdir,
}
Expand All @@ -274,6 +302,8 @@ where
model_catalog,
credentials,
workspace_root,
#[cfg(any(test, feature = "mock"))]
model_override: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
bash_sandbox: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
Expand Down Expand Up @@ -329,8 +359,15 @@ where
context.credentials.as_ref(),
with_summaries,
)?;
// Create an AgentBuilder pre-loaded with the resolved model.
let builder = AgentBuilder::<(), String>::from_arc(prepared.model().clone());
// Create an AgentBuilder with the model (override wins over catalog-resolved).
#[cfg(any(test, feature = "mock"))]
let model = context
.model_override
.clone()
.unwrap_or_else(|| prepared.model().clone());
#[cfg(not(any(test, feature = "mock")))]
let model = prepared.model().clone();
let builder = AgentBuilder::<(), String>::from_arc(model);
// Create a TaskHandle for delegation if Task tool is attached later.
let task_handle = TaskHandle::new(context.clone(), current_depth);
// Select the sandbox profile (None on non-Linux or without the feature).
Expand Down Expand Up @@ -464,6 +501,8 @@ mod tests {
model_catalog,
credentials,
workspace_root: workspace_root(),
#[cfg(any(test, feature = "mock"))]
model_override: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
bash_sandbox: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
Expand Down Expand Up @@ -504,6 +543,8 @@ mod tests {
model_catalog,
credentials,
workspace_root: workspace_root(),
#[cfg(any(test, feature = "mock"))]
model_override: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
bash_sandbox: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
Expand Down Expand Up @@ -543,6 +584,8 @@ mod tests {
model_catalog,
credentials,
workspace_root: workspace_root(),
#[cfg(any(test, feature = "mock"))]
model_override: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
bash_sandbox: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
Expand Down Expand Up @@ -578,6 +621,8 @@ mod tests {
model_catalog,
credentials,
workspace_root: workspace_root(),
#[cfg(any(test, feature = "mock"))]
model_override: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
bash_sandbox: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
Expand Down Expand Up @@ -650,6 +695,8 @@ mod tests {
model_catalog,
credentials,
workspace_root: workspace_root(),
#[cfg(any(test, feature = "mock"))]
model_override: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
bash_sandbox: None,
#[cfg(all(feature = "linux-bubblewrap", target_os = "linux"))]
Expand Down
3 changes: 3 additions & 0 deletions src/reloaded-code-serdesai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ pub use reloaded_code_agents::{
AgentDefaults, AgentRuntime, AgentRuntimeBuilder, ModelResolutionError, ResolvedModel,
resolve_model_with_catalog,
};

#[cfg(any(test, feature = "mock"))]
pub mod mock;
235 changes: 235 additions & 0 deletions src/reloaded-code-serdesai/src/mock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
//! Mock model types with streaming support for running agents without a real LLM provider.
//!
//! Wraps upstream [`serdes_ai_models`] mock types so they work with
//! [`Agent::run_stream`][`serdes_ai::Agent::run_stream`].
//!
//! # Quick start
//!
//! ```text
//! use reloaded_code_serdesai::mock::{Streamed, tool_then_text};
//! use serde_json::json;
//!
//! let model = tool_then_text("glob", json!({"pattern": "*.rs"}), "Done.");
//! let stream = agent.run_stream("prompt", ()).await?; // OK
//! ```
//!
//! When using [`crate::AgentBuildContext`], call
//! [`with_model_override`](crate::AgentBuildContext::with_model_override)
//! to inject the mock model before calling [`build()`](crate::AgentBuildContext::build).

// Re-export upstream mock types so users can still access the raw variants when needed.
pub use serdes_ai_models::{FunctionModel, MockModel, TestModel};

use async_trait::async_trait;
use futures::stream;
use serdes_ai::core::{
FinishReason, ModelRequest, ModelResponse, ModelResponsePart, ModelResponseStreamEvent,
};
use serdes_ai_models::Model as ModelTrait;
// `ModelSettings` is imported from `serdes_ai::core`; the remaining model types come from `serdes_ai_models`.
use serdes_ai::core::ModelSettings;
use serdes_ai_models::{
ModelCapability, ModelError, ModelProfile, ModelRequestParameters, StreamedResponse,
};

// ============================================================================
// Streamed - wrapper that adds streaming support to any Model
// ============================================================================

/// Wrapper adding [`request_stream`](ModelTrait::request_stream) support to any [`ModelTrait`] implementation.
///
/// Delegates [`request`](ModelTrait::request) directly to the inner model and converts the non-streaming
/// response into a sequence of [`ModelResponseStreamEvent`]s for streaming callers.
///
/// Everything except the reported name (system, profile, capability checks) is
/// forwarded to the wrapped model unchanged.
///
/// # Example
///
/// ```rust,no_run
/// use reloaded_code_serdesai::mock::{FunctionModel, Streamed};
/// use serde_json::json;
///
/// let model = Streamed::new(FunctionModel::tool_call("glob", json!({"pattern": "*.rs"})));
/// ```
#[derive(Clone, Debug)]
pub struct Streamed<T> {
// The wrapped model; every request is ultimately answered by it.
inner: T,
// Name reported by the wrapper. Copied from the inner model in `new` and
// replaceable through `with_name`.
name: String,
}

impl<T> Streamed<T> {
    /// Wraps `inner` so it can also be used through the streaming entry point.
    ///
    /// The wrapper's name starts out as a copy of the inner model's
    /// [`name()`](ModelTrait::name); use [`with_name`](Self::with_name) to change it.
    pub fn new(inner: T) -> Self
    where
        T: ModelTrait,
    {
        Self {
            // Copy the name before `inner` is moved into the wrapper.
            name: inner.name().to_owned(),
            inner,
        }
    }

    /// Overrides the name reported by the wrapper and returns it for chaining.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..self
        }
    }
}

#[async_trait]
impl<T: ModelTrait + Send + Sync> ModelTrait for Streamed<T> {
fn name(&self) -> &str {
&self.name
}

fn system(&self) -> &str {
self.inner.system()
}

fn profile(&self) -> &ModelProfile {
self.inner.profile()
}

async fn request(
&self,
messages: &[ModelRequest],
settings: &ModelSettings,
params: &ModelRequestParameters,
) -> Result<ModelResponse, ModelError> {
self.inner.request(messages, settings, params).await
}

async fn request_stream(
&self,
messages: &[ModelRequest],
settings: &ModelSettings,
params: &ModelRequestParameters,
) -> Result<StreamedResponse, ModelError> {
let response = self.inner.request(messages, settings, params).await?;
let events = response_to_stream_events(response);
Ok(Box::pin(stream::iter(events.into_iter().map(Ok))))
}

fn supports(&self, capability: ModelCapability) -> bool {
self.inner.supports(capability)
}
}

// ============================================================================
// Convenience helpers
// ============================================================================

/// Creates a two-phase mock model: a tool call first, a text answer afterwards.
///
/// While the conversation contains no tool return, the model replies with a
/// short "Calling ..." text part plus a call to `tool_name` with `args`, and
/// marks the response as finished because of a tool call. As soon as any
/// message carries a tool return, the model replies with plain text instead,
/// which keeps streaming agent examples from looping forever.
///
/// # Arguments
/// - `tool_name`: Name of the tool requested on the first turn.
/// - `args`: JSON arguments sent with that tool call.
/// - `fallback_text`: Text that opens the follow-up reply. When the collected
///   tool output is non-empty it is appended after a blank line; otherwise
///   this text is the whole reply.
///
/// # Returns
/// The model wrapped in [`Streamed`], so it also works with streaming callers.
///
/// # Example
///
/// ```rust,no_run
/// use reloaded_code_serdesai::mock::tool_then_text;
/// use serde_json::json;
///
/// let model = tool_then_text("glob", json!({"pattern": "*.rs"}), "Done.");
/// ```
pub fn tool_then_text(
    tool_name: impl Into<String>,
    args: serde_json::Value,
    fallback_text: impl Into<String>,
) -> Streamed<FunctionModel> {
    let tool_name: String = tool_name.into();
    let fallback_text: String = fallback_text.into();

    let model = FunctionModel::new(move |messages, _settings| {
        // A tool return anywhere in the history means the tool already ran.
        let tool_already_ran = messages.iter().any(|message| {
            message
                .parts
                .iter()
                .any(|part| matches!(part, serdes_ai::core::ModelRequestPart::ToolReturn(_)))
        });

        if !tool_already_ran {
            // First turn: ask the agent to execute the real tool.
            ModelResponse::with_parts(vec![
                ModelResponsePart::text(format!("Calling {tool_name}...")),
                ModelResponsePart::tool_call(tool_name.clone(), args.clone()),
            ])
            .with_finish_reason(FinishReason::ToolCall)
        } else {
            // Later turn: gather every tool return, one entry per line.
            let returned: Vec<String> = messages
                .iter()
                .flat_map(|message| message.tool_returns())
                .map(extract_tool_return_text)
                .collect();
            let tool_results = returned.join("\n");

            if tool_results.is_empty() {
                ModelResponse::text(fallback_text.clone())
            } else {
                ModelResponse::text(format!("{fallback_text}\n\n{tool_results}"))
            }
        }
    });

    Streamed::new(model)
}

// ============================================================================
// Private helpers
// ============================================================================

/// Flattens a complete response into stream events.
///
/// Each response part yields a `part_start` event carrying the part, followed
/// immediately by the matching `part_end` event, in the original part order.
fn response_to_stream_events(response: ModelResponse) -> Vec<ModelResponseStreamEvent> {
    let mut events = Vec::with_capacity(response.parts.len() * 2 + 1);

    events.extend(
        response
            .parts
            .into_iter()
            .enumerate()
            .flat_map(|(index, part)| {
                [
                    ModelResponseStreamEvent::part_start(index, part),
                    ModelResponseStreamEvent::part_end(index),
                ]
            }),
    );

    events
}

/// Extract human-readable text from a [`ToolReturnPart`].
///
/// Uses serde JSON round-tripping to avoid depending on the
/// non-public `ToolReturnContent` enum variants directly.
fn extract_tool_return_text(tr: &serdes_ai::core::ToolReturnPart) -> String {
// Serialize the content field to JSON, then extract readable text.
// ToolReturnContent variants produce:
// Text -> {"type":"text","content":"..."}
// Json -> {"type":"json","content":{...}}
// Error -> {"type":"error","message":"..."}
// Multiple-> {"type":"multiple","items":[...]}
// Image -> {"type":"image","image":{...}}
let Ok(val) = serde_json::to_value(&tr.content) else {
return format!("{:?}", tr.content);
};

// Try text content field first (most common case).
if let Some(text) = val.get("content").and_then(|v| v.as_str()) {
return text.to_string();
}

// JSON content field.
if let Some(json_val) = val.get("content")
&& let Ok(pretty) = serde_json::to_string_pretty(json_val)
{
return pretty;
}

// Error message field.
if let Some(msg) = val.get("message").and_then(|v| v.as_str()) {
return format!("[error] {msg}");
}

// Fallback: pretty-print the whole thing.
serde_json::to_string_pretty(&val).unwrap_or_else(|_| format!("{:?}", tr.content))
}
Loading