287 changes: 236 additions & 51 deletions crates/lib/src/query/mod.rs
@@ -1152,80 +1152,115 @@ mod tests {
);
}

// ------------------------------------------------------------------
// End-to-end regression tests for #103.
//
// These tests build a real QueryEngine with a mock Provider and
// exercise run_turn_with_sink directly, verifying that cancellation
// actually interrupts the streaming loop (not just the select!
// pattern in isolation).
// ------------------------------------------------------------------
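//
// For context, the streaming loop under test is assumed to follow this
// shape (an illustrative sketch, not the actual run_turn_with_sink
// body; `cancel` stands in for the engine's per-turn cancellation
// token and `rx` for the provider's event channel):
//
//     loop {
//         tokio::select! {
//             _ = cancel.cancelled() => break,   // interrupt wins the race
//             event = rx.recv() => match event {
//                 Some(ev) => { /* forward to the sink */ }
//                 None => break,                 // provider closed the stream
//             },
//         }
//     }
// ------------------------------------------------------------------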

use crate::llm::provider::{Provider, ProviderError, ProviderRequest};

/// A provider whose stream yields one TextDelta and then hangs forever.
/// Simulates the real bug: a slow LLM response the user wants to interrupt.
struct HangingProvider;

#[async_trait::async_trait]
impl Provider for HangingProvider {
fn name(&self) -> &str {
"hanging-mock"
}

async fn stream(
&self,
_request: &ProviderRequest,
) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
let (tx, rx) = tokio::sync::mpsc::channel(4);
tokio::spawn(async move {
let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
// Hang forever without closing the channel or sending Done.
let _tx_holder = tx;
std::future::pending::<()>().await;
});
Ok(rx)
}
}

/// A provider that completes a turn normally: emits text and a Done event.
struct CompletingProvider;

#[async_trait::async_trait]
impl Provider for CompletingProvider {
fn name(&self) -> &str {
"completing-mock"
}

async fn stream(
&self,
_request: &ProviderRequest,
) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
let (tx, rx) = tokio::sync::mpsc::channel(8);
tokio::spawn(async move {
let _ = tx.send(StreamEvent::TextDelta("hello".into())).await;
let _ = tx
.send(StreamEvent::ContentBlockComplete(ContentBlock::Text {
text: "hello".into(),
}))
.await;
let _ = tx
.send(StreamEvent::Done {
usage: Usage::default(),
stop_reason: Some(StopReason::EndTurn),
})
.await;
// Channel closes when tx drops.
});
Ok(rx)
}
}

fn build_engine(llm: Arc<dyn Provider>) -> QueryEngine {
use crate::config::Config;
use crate::permissions::PermissionChecker;
use crate::state::AppState;
use crate::tools::registry::ToolRegistry;

let config = Config::default();
let permissions = PermissionChecker::from_config(&config.permissions);
let state = AppState::new(config);

QueryEngine::new(
llm,
ToolRegistry::default_tools(),
permissions,
state,
QueryEngineConfig {
max_turns: Some(1),
verbose: false,
unattended: true,
},
)
}

/// Schedule a cancellation after `delay_ms` via the shared handle
/// (same path the signal handler uses). Cloning `cancel_token()`
/// directly would capture the pre-turn token, which is replaced when
/// `run_turn_with_sink` starts. Callers pick a delay long enough for
/// the stream to emit its first event, so the loop is blocked on
/// `rx.recv()` when the cancel fires.
fn schedule_cancel(engine: &QueryEngine, delay_ms: u64) {
let shared = engine.cancel_shared.clone();

tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
shared.lock().unwrap().cancel();
});
}

/// Builds a mock provider whose stream yields one TextDelta and then hangs.
/// Verifies the turn returns promptly once cancel fires.
#[tokio::test]
async fn run_turn_with_sink_interrupts_on_cancel() {
let mut engine = build_engine(Arc::new(HangingProvider));
schedule_cancel(&engine, 100);

// Run the turn. Without the fix, this hangs forever; with the fix,
// it returns Ok(()) once cancellation is detected.
let result = tokio::time::timeout(
std::time::Duration::from_secs(5),
engine.run_turn_with_sink("test input", &NullSink),
Expand All @@ -1245,4 +1280,154 @@ mod tests {
"is_query_active should be reset after cancel"
);
}

/// Regression test for the original #103 bug: the signal handler held
/// a stale clone of the cancellation token, so Ctrl+C only worked on
/// the *first* turn. This test cancels turn 1, then runs turn 2 and
/// verifies it is ALSO cancellable via the same shared handle.
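///
/// Schematically, the pre-fix wiring looked like this (an illustrative
/// sketch, not the crate's actual signal-handler code):
///
///     // Captured once at startup; goes stale as soon as a new turn
///     // replaces the engine's token.
///     let token = engine.cancel_token().clone();
///     ctrlc::set_handler(move || token.cancel());
///
/// The fix routes cancellation through the shared handle
/// (`engine.cancel_shared`), the same path `schedule_cancel` uses.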
#[tokio::test]
async fn cancel_works_across_multiple_turns() {
let mut engine = build_engine(Arc::new(HangingProvider));

// Turn 1: cancel mid-stream.
schedule_cancel(&engine, 80);
let r1 = tokio::time::timeout(
std::time::Duration::from_secs(5),
engine.run_turn_with_sink("turn 1", &NullSink),
)
.await;
assert!(r1.is_ok(), "turn 1 should cancel promptly");
assert!(!engine.state().is_query_active);

// Turn 2: cancel again via the same shared handle.
// With the pre-fix stale-token bug, the handle would be pointing
// at turn 1's already-used token and turn 2 would hang forever.
schedule_cancel(&engine, 80);
let r2 = tokio::time::timeout(
std::time::Duration::from_secs(5),
engine.run_turn_with_sink("turn 2", &NullSink),
)
.await;
assert!(
r2.is_ok(),
"turn 2 should also cancel promptly — regression would hang here"
);
assert!(!engine.state().is_query_active);

// Turn 3: one more for good measure.
schedule_cancel(&engine, 80);
let r3 = tokio::time::timeout(
std::time::Duration::from_secs(5),
engine.run_turn_with_sink("turn 3", &NullSink),
)
.await;
assert!(r3.is_ok(), "turn 3 should still be cancellable");
assert!(!engine.state().is_query_active);
}

/// Verifies that a previously-cancelled token does not poison subsequent
/// turns. A fresh run_turn_with_sink on the same engine should complete
/// normally even after a prior cancel.
#[tokio::test]
async fn cancel_does_not_poison_next_turn() {
// Turn 1: hangs and gets cancelled.
let mut engine = build_engine(Arc::new(HangingProvider));
schedule_cancel(&engine, 80);
let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
engine.run_turn_with_sink("turn 1", &NullSink),
)
.await
.expect("turn 1 should cancel");

// Swap the provider to one that completes normally by rebuilding
// the engine (we can't swap llm on an existing engine, so this
// simulates the isolated "fresh turn" behavior). The key property
// being tested is that the per-turn cancel reset correctly
// initializes a non-cancelled token.
let mut engine2 = build_engine(Arc::new(CompletingProvider));

// Pre-cancel engine2 to simulate a leftover cancelled state, then
// verify run_turn_with_sink still works because it resets the token.
engine2.cancel_shared.lock().unwrap().cancel();

let result = tokio::time::timeout(
std::time::Duration::from_secs(5),
engine2.run_turn_with_sink("hello", &NullSink),
)
.await;

assert!(result.is_ok(), "completing turn should not hang");
assert!(
result.unwrap().is_ok(),
"turn should succeed — the stale cancel flag must be cleared on turn start"
);
// Message history should contain the user + assistant messages.
assert!(
engine2.state().messages.len() >= 2,
"normal turn should push both user and assistant messages"
);
}

/// Verifies that cancelling BEFORE any event arrives still interrupts
/// the turn cleanly (edge case: cancellation races with the first recv).
#[tokio::test]
async fn cancel_before_first_event_interrupts_cleanly() {
let mut engine = build_engine(Arc::new(HangingProvider));
// Very short delay so cancel likely fires before or during the
// first event. The test is tolerant of either ordering.
schedule_cancel(&engine, 1);

let result = tokio::time::timeout(
std::time::Duration::from_secs(5),
engine.run_turn_with_sink("immediate", &NullSink),
)
.await;

assert!(result.is_ok(), "early cancel should not hang");
assert!(result.unwrap().is_ok());
assert!(!engine.state().is_query_active);
}

/// Verifies the sink receives cancellation feedback via on_warning.
#[tokio::test]
async fn cancelled_turn_emits_warning_to_sink() {
use std::sync::Mutex;

/// Captures sink events for assertion.
struct CapturingSink {
warnings: Mutex<Vec<String>>,
}

impl StreamSink for CapturingSink {
fn on_text(&self, _: &str) {}
fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
fn on_error(&self, _: &str) {}
fn on_warning(&self, msg: &str) {
self.warnings.lock().unwrap().push(msg.to_string());
}
}

let sink = CapturingSink {
warnings: Mutex::new(Vec::new()),
};

let mut engine = build_engine(Arc::new(HangingProvider));
schedule_cancel(&engine, 100);

let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
engine.run_turn_with_sink("test", &sink),
)
.await
.expect("should not hang");

let warnings = sink.warnings.lock().unwrap();
assert!(
warnings.iter().any(|w| w.contains("Cancelled")),
"expected 'Cancelled' warning in sink, got: {:?}",
*warnings
);
}
}