diff --git a/crates/lib/src/query/mod.rs b/crates/lib/src/query/mod.rs
index 202eb3b..aaba861 100644
--- a/crates/lib/src/query/mod.rs
+++ b/crates/lib/src/query/mod.rs
@@ -1152,56 +1152,88 @@ mod tests {
         );
     }
 
-    /// End-to-end regression test for #103: verifies that cancelling a
-    /// QueryEngine turn mid-stream actually interrupts the loop.
-    ///
-    /// Builds a mock provider whose stream hangs forever after one event,
-    /// starts a real turn, then calls `engine.cancel()` and asserts the
-    /// turn returns quickly. Without the tokio::select! fix, this would
-    /// hang until the test timeout.
-    #[tokio::test]
-    async fn run_turn_with_sink_interrupts_on_cancel() {
-        use crate::config::Config;
-        use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
-        use crate::permissions::PermissionChecker;
-        use crate::state::AppState;
-        use crate::tools::registry::ToolRegistry;
+    // ------------------------------------------------------------------
+    // End-to-end regression tests for #103.
+    //
+    // These tests build a real QueryEngine with a mock Provider and
+    // exercise run_turn_with_sink directly, verifying that cancellation
+    // actually interrupts the streaming loop (not just the select!
+    // pattern in isolation).
+    // ------------------------------------------------------------------
+
+    use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
+
+    /// A provider whose stream yields one TextDelta and then hangs forever.
+    /// Simulates the real bug: a slow LLM response the user wants to interrupt.
+    struct HangingProvider;
+
+    #[async_trait::async_trait]
+    impl Provider for HangingProvider {
+        fn name(&self) -> &str {
+            "hanging-mock"
+        }
 
-        /// A provider whose stream yields one TextDelta and then hangs forever.
-        struct HangingProvider;
+        async fn stream(
+            &self,
+            _request: &ProviderRequest,
+        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
+            let (tx, rx) = tokio::sync::mpsc::channel(4);
+            tokio::spawn(async move {
+                let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
+                // Hang forever without closing the channel or sending Done.
+                let _tx_holder = tx;
+                std::future::pending::<()>().await;
+            });
+            Ok(rx)
+        }
+    }
 
-        #[async_trait::async_trait]
-        impl Provider for HangingProvider {
-            fn name(&self) -> &str {
-                "hanging-mock"
-            }
+    /// A provider that completes a turn normally: emits text and a Done event.
+    struct CompletingProvider;
 
-            async fn stream(
-                &self,
-                _request: &ProviderRequest,
-            ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
-                let (tx, rx) = tokio::sync::mpsc::channel(4);
-                tokio::spawn(async move {
-                    let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
-                    // Hang forever without closing the channel or sending Done.
-                    // This simulates the real bug: a slow LLM response that
-                    // the user wants to interrupt.
-                    let _tx_holder = tx;
-                    std::future::pending::<()>().await;
-                });
-                Ok(rx)
-            }
+    #[async_trait::async_trait]
+    impl Provider for CompletingProvider {
+        fn name(&self) -> &str {
+            "completing-mock"
         }
 
-        let llm = Arc::new(HangingProvider);
-        let tools = ToolRegistry::default_tools();
+        async fn stream(
+            &self,
+            _request: &ProviderRequest,
+        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
+            let (tx, rx) = tokio::sync::mpsc::channel(8);
+            tokio::spawn(async move {
+                let _ = tx.send(StreamEvent::TextDelta("hello".into())).await;
+                let _ = tx
+                    .send(StreamEvent::ContentBlockComplete(ContentBlock::Text {
+                        text: "hello".into(),
+                    }))
+                    .await;
+                let _ = tx
+                    .send(StreamEvent::Done {
+                        usage: Usage::default(),
+                        stop_reason: Some(StopReason::EndTurn),
+                    })
+                    .await;
+                // Channel closes when tx drops.
+            });
+            Ok(rx)
+        }
+    }
+
+    fn build_engine(llm: Arc<dyn Provider>) -> QueryEngine {
+        use crate::config::Config;
+        use crate::permissions::PermissionChecker;
+        use crate::state::AppState;
+        use crate::tools::registry::ToolRegistry;
+
         let config = Config::default();
         let permissions = PermissionChecker::from_config(&config.permissions);
         let state = AppState::new(config);
-        let mut engine = QueryEngine::new(
+        QueryEngine::new(
             llm,
-            tools,
+            ToolRegistry::default_tools(),
             permissions,
             state,
             QueryEngineConfig {
@@ -1209,23 +1241,26 @@ mod tests {
                 verbose: false,
                 unattended: true,
             },
-        );
+        )
+    }
 
-        // Clone the shared handle so the background task can cancel the
-        // *current* turn's token (the same path the signal handler uses).
-        // Cloning cancel_token() directly would capture the pre-turn token,
-        // which gets replaced when run_turn_with_sink starts.
+    /// Schedule a cancellation after `delay_ms` via the shared handle
+    /// (same path the signal handler uses).
+    fn schedule_cancel(engine: &QueryEngine, delay_ms: u64) {
         let shared = engine.cancel_shared.clone();
-
-        // Cancel the engine after a short delay so the stream has time
-        // to produce its first event and the loop is blocked on rx.recv().
         tokio::spawn(async move {
-            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
             shared.lock().unwrap().cancel();
         });
+    }
+
+    /// Builds a mock provider whose stream yields one TextDelta and then hangs.
+    /// Verifies the turn returns promptly once cancel fires.
+    #[tokio::test]
+    async fn run_turn_with_sink_interrupts_on_cancel() {
+        let mut engine = build_engine(Arc::new(HangingProvider));
+        schedule_cancel(&engine, 100);
 
-        // Run the turn. Without the fix, this hangs forever; with the fix,
-        // it returns Ok(()) once cancellation is detected.
         let result = tokio::time::timeout(
             std::time::Duration::from_secs(5),
             engine.run_turn_with_sink("test input", &NullSink),
@@ -1245,4 +1280,154 @@ mod tests {
             "is_query_active should be reset after cancel"
         );
     }
+
+    /// Regression test for the original #103 bug: the signal handler held
+    /// a stale clone of the cancellation token, so Ctrl+C only worked on
+    /// the *first* turn. This test cancels turn 1, then runs turn 2 and
+    /// verifies it is ALSO cancellable via the same shared handle.
+    #[tokio::test]
+    async fn cancel_works_across_multiple_turns() {
+        let mut engine = build_engine(Arc::new(HangingProvider));
+
+        // Turn 1: cancel mid-stream.
+        schedule_cancel(&engine, 80);
+        let r1 = tokio::time::timeout(
+            std::time::Duration::from_secs(5),
+            engine.run_turn_with_sink("turn 1", &NullSink),
+        )
+        .await;
+        assert!(r1.is_ok(), "turn 1 should cancel promptly");
+        assert!(!engine.state().is_query_active);
+
+        // Turn 2: cancel again via the same shared handle.
+        // With the pre-fix stale-token bug, the handle would be pointing
+        // at turn 1's already-used token and turn 2 would hang forever.
+        schedule_cancel(&engine, 80);
+        let r2 = tokio::time::timeout(
+            std::time::Duration::from_secs(5),
+            engine.run_turn_with_sink("turn 2", &NullSink),
+        )
+        .await;
+        assert!(
+            r2.is_ok(),
+            "turn 2 should also cancel promptly — regression would hang here"
+        );
+        assert!(!engine.state().is_query_active);
+
+        // Turn 3: one more for good measure.
+        schedule_cancel(&engine, 80);
+        let r3 = tokio::time::timeout(
+            std::time::Duration::from_secs(5),
+            engine.run_turn_with_sink("turn 3", &NullSink),
+        )
+        .await;
+        assert!(r3.is_ok(), "turn 3 should still be cancellable");
+        assert!(!engine.state().is_query_active);
+    }
+
+    /// Verifies that a previously-cancelled token does not poison subsequent
+    /// turns. A fresh run_turn_with_sink on the same engine should complete
+    /// normally even after a prior cancel.
+    #[tokio::test]
+    async fn cancel_does_not_poison_next_turn() {
+        // Turn 1: hangs and gets cancelled.
+        let mut engine = build_engine(Arc::new(HangingProvider));
+        schedule_cancel(&engine, 80);
+        let _ = tokio::time::timeout(
+            std::time::Duration::from_secs(5),
+            engine.run_turn_with_sink("turn 1", &NullSink),
+        )
+        .await
+        .expect("turn 1 should cancel");
+
+        // Swap the provider to one that completes normally by rebuilding
+        // the engine (we can't swap llm on an existing engine, so this
+        // simulates the isolated "fresh turn" behavior). The key property
+        // being tested is that the per-turn cancel reset correctly
+        // initializes a non-cancelled token.
+        let mut engine2 = build_engine(Arc::new(CompletingProvider));
+
+        // Pre-cancel engine2 to simulate a leftover cancelled state, then
+        // verify run_turn_with_sink still works because it resets the token.
+        engine2.cancel_shared.lock().unwrap().cancel();
+
+        let result = tokio::time::timeout(
+            std::time::Duration::from_secs(5),
+            engine2.run_turn_with_sink("hello", &NullSink),
+        )
+        .await;
+
+        assert!(result.is_ok(), "completing turn should not hang");
+        assert!(
+            result.unwrap().is_ok(),
+            "turn should succeed — the stale cancel flag must be cleared on turn start"
+        );
+        // Message history should contain the user + assistant messages.
+        assert!(
+            engine2.state().messages.len() >= 2,
+            "normal turn should push both user and assistant messages"
+        );
+    }
+
+    /// Verifies that cancelling BEFORE any event arrives still interrupts
+    /// the turn cleanly (edge case: cancellation races with the first recv).
+    #[tokio::test]
+    async fn cancel_before_first_event_interrupts_cleanly() {
+        let mut engine = build_engine(Arc::new(HangingProvider));
+        // Very short delay so cancel likely fires before or during the
+        // first event. The test is tolerant of either ordering.
+        schedule_cancel(&engine, 1);
+
+        let result = tokio::time::timeout(
+            std::time::Duration::from_secs(5),
+            engine.run_turn_with_sink("immediate", &NullSink),
+        )
+        .await;
+
+        assert!(result.is_ok(), "early cancel should not hang");
+        assert!(result.unwrap().is_ok());
+        assert!(!engine.state().is_query_active);
+    }
+
+    /// Verifies the sink receives cancellation feedback via on_warning.
+    #[tokio::test]
+    async fn cancelled_turn_emits_warning_to_sink() {
+        use std::sync::Mutex;
+
+        /// Captures sink events for assertion.
+        struct CapturingSink {
+            warnings: Mutex<Vec<String>>,
+        }
+
+        impl StreamSink for CapturingSink {
+            fn on_text(&self, _: &str) {}
+            fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
+            fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
+            fn on_error(&self, _: &str) {}
+            fn on_warning(&self, msg: &str) {
+                self.warnings.lock().unwrap().push(msg.to_string());
+            }
+        }
+
+        let sink = CapturingSink {
+            warnings: Mutex::new(Vec::new()),
+        };
+
+        let mut engine = build_engine(Arc::new(HangingProvider));
+        schedule_cancel(&engine, 100);
+
+        let _ = tokio::time::timeout(
+            std::time::Duration::from_secs(5),
+            engine.run_turn_with_sink("test", &sink),
+        )
+        .await
+        .expect("should not hang");
+
+        let warnings = sink.warnings.lock().unwrap();
+        assert!(
+            warnings.iter().any(|w| w.contains("Cancelled")),
+            "expected 'Cancelled' warning in sink, got: {:?}",
+            *warnings
+        );
+    }
 }