diff --git a/src/agent.rs b/src/agent.rs index d0c8f6679..e7b4f03ba 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -28,6 +28,7 @@ use crate::build; use crate::bus::{Bus, BusEvent, SubagentStatus, ToolEvent, ToolStatus}; use crate::cache_tracker::CacheTracker; use crate::compaction::CompactionEvent; +use crate::prefix_cache_stable; use crate::id; use crate::logging; use crate::message::{ @@ -525,6 +526,19 @@ impl Agent { self.note_compaction_applied(); self.persist_session_best_effort("compaction completion"); } + let messages = if prefix_cache_stable::is_prefix_cache_stable_mode() { + let (truncated, truncate_count) = + prefix_cache_stable::truncate_tool_results_for_api(&messages); + if truncate_count > 0 { + logging::info(&format!( + "Prefix-cache mode: truncated {} tool results for API", + truncate_count + )); + } + truncated + } else { + messages + }; let user_count = messages .iter() .filter(|message| matches!(message.role, Role::User)) @@ -546,6 +560,19 @@ impl Agent { let all_messages = self.session.provider_messages(); let messages = all_messages.to_vec(); + let messages = if prefix_cache_stable::is_prefix_cache_stable_mode() { + let (truncated, truncate_count) = + prefix_cache_stable::truncate_tool_results_for_api(&messages); + if truncate_count > 0 { + logging::info(&format!( + "Prefix-cache mode: truncated {} tool results for API (session path)", + truncate_count + )); + } + truncated + } else { + messages + }; let user_count = messages .iter() .filter(|message| matches!(message.role, Role::User)) diff --git a/src/agent/turn_loops.rs b/src/agent/turn_loops.rs index aa35525d9..a35749624 100644 --- a/src/agent/turn_loops.rs +++ b/src/agent/turn_loops.rs @@ -49,6 +49,29 @@ impl Agent { // false-positive violations every turn (prior turn's memory ≠ current history prefix). self.record_client_cache_request(&messages); + // Preflight check for DeepSeek prefix-cache stability mode. + // If the payload is near the context limit, warn early so compaction + // can fold history before the API returns a 400. + if prefix_cache_stable::is_prefix_cache_stable_mode() { + let preflight = + prefix_cache_stable::preflight_check(&messages, &tools, &self.provider.model()); + if preflight.needs_action { + logging::warn(&format!( + "Prefix-cache preflight: context at {:.1}% ({} / {} tokens) — emergency fold recommended", + preflight.ratio * 100.0, + preflight.estimate_tokens, + preflight.ctx_max, + )); + } else if preflight.ratio > 0.5 { + logging::info(&format!( + "Prefix-cache preflight: context at {:.1}% ({} / {} tokens)", + preflight.ratio * 100.0, + preflight.estimate_tokens, + preflight.ctx_max, + )); + } + } + // Inject memory as a user message at the end (preserves cache prefix) let mut messages_with_memory: Vec = messages.iter().cloned().collect(); if let Some(memory) = memory_pending.as_ref() { @@ -503,6 +526,13 @@ impl Agent { usage_cache_read, usage_cache_creation, ); + // Record cache usage for prefix-cache hit-rate tracking + self.cache_tracker.record_usage(usage_cache_read, usage_input.unwrap_or(0)); + if prefix_cache_stable::is_prefix_cache_stable_mode() + && self.cache_tracker.usage_turn_count() % 5 == 0 + { + logging::info(&format!("Prefix-cache stats: {}", self.cache_tracker.cache_hit_summary())); + } } if print_output diff --git a/src/cache_tracker.rs b/src/cache_tracker.rs index 8061fee0c..724319ceb 100644 --- a/src/cache_tracker.rs +++ b/src/cache_tracker.rs @@ -26,6 +26,12 @@ pub struct CacheTracker { hash_history: VecDeque, /// Whether append-only was violated on the last request last_violation: Option, + /// Cumulative cache hit tokens (from provider-reported usage) + cache_hit_tokens: u64, + /// Cumulative cache miss tokens (from provider-reported usage) + cache_miss_tokens: u64, + /// Number of turns with usage data recorded + usage_turns: u32, } /// Information about a cache violation @@ -207,6 +213,59 @@ impl CacheTracker { pub fn had_violation(&self) -> bool { self.last_violation.is_some() } + + /// Record provider-reported cache usage for cache-hit-rate tracking. + /// Call this after each successful API response when usage data is available. + pub fn record_usage(&mut self, cache_read_input_tokens: Option, input_tokens: u64) { + if let Some(hit) = cache_read_input_tokens { + self.cache_hit_tokens += hit; + // Miss tokens = total input minus cache hits + let miss = input_tokens.saturating_sub(hit); + self.cache_miss_tokens += miss; + } else { + // Provider doesn't report cache hits; count all as miss + self.cache_miss_tokens += input_tokens; + } + self.usage_turns += 1; + } + + /// Cumulative cache hit tokens + pub fn cache_hit_tokens(&self) -> u64 { + self.cache_hit_tokens + } + + /// Cumulative cache miss tokens + pub fn cache_miss_tokens(&self) -> u64 { + self.cache_miss_tokens + } + + /// Number of turns with usage data recorded + pub fn usage_turn_count(&self) -> u32 { + self.usage_turns + } + + /// Cache hit rate as a ratio (0.0–1.0), or None if no usage recorded + pub fn cache_hit_rate(&self) -> Option { + let total = self.cache_hit_tokens + self.cache_miss_tokens; + if total == 0 { + return None; + } + Some(self.cache_hit_tokens as f64 / total as f64) + } + + /// Human-readable cache hit summary + pub fn cache_hit_summary(&self) -> String { + match self.cache_hit_rate() { + None => "no cache usage data yet".to_string(), + Some(rate) => format!( + "cache hit: {:.1}% ({} hit / {} miss tokens over {} turns)", + rate * 100.0, + self.cache_hit_tokens, + self.cache_miss_tokens, + self.usage_turns + ), + } + } } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 14a651cba..a0f8ee0f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ pub mod build; pub mod bus; pub mod cache_tracker; pub mod catchup; +pub mod prefix_cache_stable; pub mod channel; pub mod cli; pub mod compaction; diff --git a/src/prefix_cache_stable.rs b/src/prefix_cache_stable.rs new file mode 100644 index 000000000..dafdfbb45 --- /dev/null +++ b/src/prefix_cache_stable.rs @@ -0,0 +1,336 @@ +//! DeepSeek prefix cache stability design (adapted from Reasonix Pillar 1) +//! +//! When profile=deepseek, this module ensures the message prefix stays byte-stable +//! across turns, maximizing DeepSeek's automatic prefix cache hit rate. +//! +//! Core invariants (adapted from Reasonix): +//! 1. Immutable Prefix: system prompt + tool specs + few-shots are fixed per session +//! 2. Append-Only Log: messages grow monotonically; no rewrites of prior turns +//! 3. Volatile Scratch: reasoning content is transient, never sent upstream +//! 4. Preflight Check: local token estimation catches oversized payloads before sending +//! 5. Turn-End Truncation: oversized tool results are shrunk at turn end + +use crate::message::{ContentBlock, Message, Role, ToolDefinition}; + +/// DeepSeek V4 context window (1M tokens for direct API) +const DEEPSEEK_V4_CONTEXT_TOKENS: usize = 1_000_000; + +/// Default context window fallback +const DEFAULT_CONTEXT_TOKENS: usize = 128_000; + +/// Threshold at which we consider folding history +const HISTORY_FOLD_THRESHOLD: f64 = 0.5; +/// Tail budget after a normal fold, as fraction of ctx_max +const HISTORY_FOLD_TAIL_FRACTION: f64 = 0.2; +/// Aggressive fold threshold +const HISTORY_FOLD_AGGRESSIVE_THRESHOLD: f64 = 0.7; +/// Aggressive tail fraction +const HISTORY_FOLD_AGGRESSIVE_TAIL_FRACTION: f64 = 0.1; +/// Force summary exit threshold +const FORCE_SUMMARY_THRESHOLD: f64 = 0.8; +/// Emergency preflight threshold +const PREFLIGHT_EMERGENCY_THRESHOLD: f64 = 0.95; +/// Turn-end tool result cap in tokens +const TURN_END_RESULT_CAP_TOKENS: usize = 3000; +/// Max chars for a tool result after turn-end truncation (chars / 4 heuristic) +const TURN_END_RESULT_CAP_CHARS: usize = TURN_END_RESULT_CAP_TOKENS * 4; + +/// Detect if prefix cache stable mode should be active. +/// +/// This checks multiple signals that indicate the user is running against +/// DeepSeek's API (direct or via OpenRouter), where prefix-cache mechanics +/// differ from Anthropic's explicit cache-control. +pub fn is_prefix_cache_stable_mode() -> bool { + // Primary: the OpenRouter/OpenAI-compatible cache namespace is set to deepseek + if let Ok(namespace) = std::env::var("JCODE_OPENROUTER_CACHE_NAMESPACE") { + if namespace.trim().eq_ignore_ascii_case("deepseek") { + return true; + } + } + // Secondary: runtime provider hint + if let Ok(provider) = std::env::var("JCODE_RUNTIME_PROVIDER") { + if provider.trim().eq_ignore_ascii_case("deepseek") { + return true; + } + } + // Tertiary: named provider profile active + if let Ok(profile) = std::env::var("JCODE_NAMED_PROVIDER_PROFILE") { + if profile.trim().eq_ignore_ascii_case("deepseek") { + return true; + } + } + false +} + +/// Preflight decision before sending a request. +#[derive(Debug, Clone)] +pub struct PreflightDecision { + /// Whether action is needed (compact or abort) + pub needs_action: bool, + /// Estimated token count + pub estimate_tokens: usize, + /// Context window size + pub ctx_max: usize, + /// Ratio of estimate to ctx_max + pub ratio: f64, +} + +/// Local-side preflight before sending a request — catches oversized payloads early. +/// +/// Adapted from Reasonix ContextManager::decidePreflight. +pub fn preflight_check(messages: &[Message], tools: &[ToolDefinition], model: &str) -> PreflightDecision { + let ctx_max = context_tokens_for_model(model); + let estimate = estimate_request_tokens(messages, tools); + let ratio = if ctx_max > 0 { + estimate as f64 / ctx_max as f64 + } else { + 0.0 + }; + PreflightDecision { + needs_action: ratio > PREFLIGHT_EMERGENCY_THRESHOLD, + estimate_tokens: estimate, + ctx_max, + ratio, + } +} + +/// Post-usage decision after a turn's response. +#[derive(Debug, Clone)] +pub enum PostUsageAction { + /// No action needed + None, + /// Fold history, keeping recent tail within budget + Fold { tail_budget: usize, aggressive: bool }, + /// Exit turn with a forced summary + ExitWithSummary, +} + +/// Decide what to do after receiving usage data from the provider. +/// +/// Adapted from Reasonix ContextManager::decideAfterUsage. +pub fn decide_after_usage( + prompt_tokens: usize, + model: &str, + already_folded_this_turn: bool, +) -> PostUsageAction { + let ctx_max = context_tokens_for_model(model); + if ctx_max == 0 { + return PostUsageAction::None; + } + let ratio = prompt_tokens as f64 / ctx_max as f64; + + if ratio > FORCE_SUMMARY_THRESHOLD { + return PostUsageAction::ExitWithSummary; + } + if already_folded_this_turn { + return PostUsageAction::None; + } + if ratio > HISTORY_FOLD_AGGRESSIVE_THRESHOLD { + return PostUsageAction::Fold { + tail_budget: (ctx_max as f64 * HISTORY_FOLD_AGGRESSIVE_TAIL_FRACTION) as usize, + aggressive: true, + }; + } + if ratio > HISTORY_FOLD_THRESHOLD { + return PostUsageAction::Fold { + tail_budget: (ctx_max as f64 * HISTORY_FOLD_TAIL_FRACTION) as usize, + aggressive: false, + }; + } + PostUsageAction::None +} + +/// Estimate tokens for a request (messages + tools + overhead). +/// +/// Uses jcode's standard chars/4 heuristic plus system/tool overhead. +pub fn estimate_request_tokens(messages: &[Message], tools: &[ToolDefinition]) -> usize { + use jcode_compaction_core::{CHARS_PER_TOKEN, DEFAULT_TOKEN_BUDGET}; + + let msg_chars: usize = messages + .iter() + .map(jcode_compaction_core::message_char_count) + .sum(); + + let tool_chars = ToolDefinition::aggregate_prompt_chars(tools); + let total_chars = msg_chars + tool_chars; + let msg_tokens = total_chars / CHARS_PER_TOKEN; + + // Conservative overhead for system prompt + tool definitions. + // SYSTEM_OVERHEAD_TOKENS (18k) is calibrated for Anthropic-sized system + // prompts; DeepSeek/OpenAI-compatible paths are typically smaller. + let overhead = if DEFAULT_TOKEN_BUDGET >= 32000 { + 1_000 + } else { + 200 + }; + + msg_tokens + overhead +} + +/// Truncate oversized tool results when preparing messages for the API. +/// +/// Every tool result exceeding the cap is shrunk so that subsequent turns +/// do not pay full price to re-read it. The model had the full text on the +/// turn that originally received it; later turns see a compact reminder. +/// +/// Returns the number of tool results that were truncated. Operates on cloned +/// messages so the original session history is untouched. +pub fn truncate_tool_results_for_api(messages: &[Message]) -> (Vec, usize) { + let mut truncated_count = 0usize; + let mut result = Vec::with_capacity(messages.len()); + + for msg in messages { + let mut new_msg = msg.clone(); + if matches!(new_msg.role, Role::User) { + for block in new_msg.content.iter_mut() { + if let ContentBlock::ToolResult { content, .. } = block { + if content.len() > TURN_END_RESULT_CAP_CHARS { + let original_len = content.len(); + let truncated_text = + crate::util::truncate_str(content, TURN_END_RESULT_CAP_CHARS); + *content = format!( + "{}\n\n[truncated from {} chars — re-read the source if full output needed]", + truncated_text, original_len + ); + truncated_count += 1; + } + } + } + } + result.push(new_msg); + } + + (result, truncated_count) +} + +/// Context window size for a given model. +pub fn context_tokens_for_model(model: &str) -> usize { + let model = model.trim().to_ascii_lowercase(); + if model.starts_with("deepseek-v4-") { + DEEPSEEK_V4_CONTEXT_TOKENS + } else { + jcode_provider_core::context_limit_for_model(&model) + .unwrap_or(DEFAULT_CONTEXT_TOKENS) + } +} + +/// Human-readable label for context usage ratio. +pub fn usage_ratio_label(ratio: f64) -> &'static str { + match ratio { + r if r >= PREFLIGHT_EMERGENCY_THRESHOLD => "critical", + r if r >= FORCE_SUMMARY_THRESHOLD => "danger", + r if r >= HISTORY_FOLD_AGGRESSIVE_THRESHOLD => "high", + r if r >= HISTORY_FOLD_THRESHOLD => "elevated", + _ => "normal", + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_message(role: Role, text: &str) -> Message { + Message { + role, + content: vec![ContentBlock::Text { + text: text.to_string(), + cache_control: None, + }], + timestamp: None, + tool_duration_ms: None, + } + } + + fn make_tool_result(content: &str) -> Message { + Message { + role: Role::User, + content: vec![ContentBlock::ToolResult { + tool_use_id: "test-id".to_string(), + content: content.to_string(), + is_error: None, + }], + timestamp: None, + tool_duration_ms: None, + } + } + + #[test] + fn test_preflight_normal() { + let msgs = vec![make_message(Role::User, "Hello")]; + let decision = preflight_check(&msgs, &[], "deepseek-v4-flash"); + assert!(!decision.needs_action); + // Overhead is 1k tokens; single word is negligible + assert!(decision.estimate_tokens < 2000, "estimate was {}", decision.estimate_tokens); + assert_eq!(decision.ctx_max, DEEPSEEK_V4_CONTEXT_TOKENS); + } + + #[test] + fn test_preflight_emergency() { + // Create enough messages to push us past the emergency threshold + // Need > 95% of 1M tokens => > 950k tokens => > 3.8M chars total + let chars_per_msg = DEEPSEEK_V4_CONTEXT_TOKENS / 2; // ~500k chars each + let mut msgs = Vec::new(); + for _ in 0..10 { + msgs.push(make_message( + Role::User, + &"x".repeat(chars_per_msg), + )); + } + let decision = preflight_check(&msgs, &[], "deepseek-v4-flash"); + assert!(decision.needs_action, "expected emergency but ratio was {:.3}", decision.ratio); + assert!(decision.ratio > PREFLIGHT_EMERGENCY_THRESHOLD); + } + + #[test] + fn test_truncate_tool_results() { + let long_content = "x".repeat(TURN_END_RESULT_CAP_CHARS + 1000); + let msgs = vec![make_tool_result(&long_content)]; + let (truncated, truncate_count) = truncate_tool_results_for_api(&msgs); + assert!(truncate_count > 0); + match &truncated[0].content[0] { + ContentBlock::ToolResult { content, .. } => { + assert!(content.len() <= TURN_END_RESULT_CAP_CHARS + 200); + } + _ => panic!("expected ToolResult"), + } + } + + #[test] + fn test_truncate_skips_short_results() { + let short_content = "short result"; + let msgs = vec![make_tool_result(short_content)]; + let (truncated, truncate_count) = truncate_tool_results_for_api(&msgs); + assert_eq!(truncate_count, 0); + match &truncated[0].content[0] { + ContentBlock::ToolResult { content, .. } => { + assert_eq!(content, short_content); + } + _ => panic!("expected ToolResult"), + } + } + + #[test] + fn test_decide_after_usage() { + let ctx_max = DEEPSEEK_V4_CONTEXT_TOKENS; + + assert!( + matches!(decide_after_usage((ctx_max as f64 * 0.3) as usize, "deepseek-v4-flash", false), PostUsageAction::None) + ); + assert!( + matches!(decide_after_usage((ctx_max as f64 * 0.55) as usize, "deepseek-v4-flash", false), PostUsageAction::Fold { .. }) + ); + assert!( + matches!(decide_after_usage((ctx_max as f64 * 0.75) as usize, "deepseek-v4-flash", false), PostUsageAction::Fold { aggressive: true, .. }) + ); + assert!( + matches!(decide_after_usage((ctx_max as f64 * 0.85) as usize, "deepseek-v4-flash", false), PostUsageAction::ExitWithSummary) + ); + } + + #[test] + fn test_context_tokens_for_model() { + assert_eq!(context_tokens_for_model("deepseek-v4-flash"), DEEPSEEK_V4_CONTEXT_TOKENS); + assert_eq!(context_tokens_for_model("deepseek-v4-pro"), DEEPSEEK_V4_CONTEXT_TOKENS); + assert_eq!(context_tokens_for_model("gpt-4"), DEFAULT_CONTEXT_TOKENS); + } +}