From b647005cb31020f3ed2b1686755397aa70334b9e Mon Sep 17 00:00:00 2001 From: matiaspalmac Date: Wed, 29 Apr 2026 01:10:33 -0400 Subject: [PATCH] feat: cloud AI streaming for OpenAI / Anthropic / Gemini / OpenRouter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #284 shipped cloud chat as a single-shot non-streaming bridge — every reply landed atomically once the request finished. That made the cloud branch feel slower than Ollama, where the chat panel renders tokens as they arrive. This wires SSE streaming through a new `cloud_proxy_stream` Tauri command so the UX matches. Rust - New `cloud_proxy_stream` and `cloud_proxy_cancel` commands behind the existing `validate_cloud_proxy_request` policy. The command spawns a detached tokio task, registers a oneshot cancel handle keyed by the caller-supplied `streamId`, and emits `cloud-stream` events as `{ kind: "data" | "done" | "error", data?, error? }`. - New `split_sse_events` parser splits the response buffer on `\n\n` / `\r\n\r\n`, joins multi-`data:` lines per event, and drops `event:` / `id:` / `retry:` / comment lines. The Rust side stays provider-agnostic; the per-provider JSON shape is parsed in TS. - The `data: [DONE]` sentinel that OpenAI / OpenRouter use is collapsed to a structured `done` event so callers don't have to special-case it. Anthropic and Gemini close the connection silently on completion; the task synthesises a `done` event on EOF so the awaiter resolves uniformly. - Total stream size capped by the same `CLOUD_PROXY_MAX_BODY` limit (16 MiB) that `cloud_proxy` uses, and per-request timeout uses `CLOUD_PROXY_TIMEOUT_SECS` (60 s). Cancellation tears the task down via the oneshot. - 10 unit tests for `split_sse_events` covering single events, multi-line data joining, CRLF separators, partial-event buffering, the [DONE] sentinel, and the SSE-spec single-leading-space rule. TypeScript - New `streamCloudChat(req, onDelta)` mirroring `streamOllamaChat`'s shape: subscribes to `cloud-stream`, filters by `streamId`, parses the provider-specific delta, and resolves with the full text on done. - Provider-specific delta extractors split into `extractStreamDelta(provider, raw)`: - OpenAI / OpenRouter: `choices[0].delta.content` - Anthropic: `content_block_delta` events with `delta.type === "text_delta"` only — `message_start`, `content_block_start`, `ping`, `message_stop`, `input_json_delta` (future tool deltas) are ignored - Gemini: `candidates[*].content.parts[*].text` joined - The matching `extractFullResponse` is exported and unit-tested so the non-streaming `cloudChat` shares the same parsing surface. - The provider URL / headers / body builder is consolidated into `buildProviderRequest(req, stream)`. Gemini's URL flips between `:generateContent` and `:streamGenerateContent?alt=sse`; the others just toggle `stream` in the body. - 12 unit tests covering each provider's streaming and full-response branches, including the role-only chunk that OpenAI sends as the first delta and Anthropic's housekeeping events. UI wiring - `AiChatComponent` cloud branch now uses `streamCloudChat` with the same `placeholderPushed` / `appendToLastAiMessage` pattern the Ollama branch uses. If the stream fails before any content arrives we still render a clean error bubble; if it fails mid-stream we append the error inline so the user sees how far the model got. - Tools / agent mode for cloud providers stays out of scope. Each provider has its own tool-call shape (`tool_calls` for OpenAI, `tool_use` blocks for Anthropic, `functionDeclarations` for Gemini) and that lands in the next PR. Verification - pnpm tsc --noEmit / lint --max-warnings 0 clean. - pnpm vitest run: 125 / 125 (was 113 — added 12 provider tests). - pnpm build: /, /_not-found, /floating prerender clean. - cargo build / clippy --lib -- -D warnings / fmt --check clean. - cargo test --lib: 112 / 112 (was 102 — added 10 SSE-parser tests). --- apps/desktop/src-tauri/src/lib.rs | 406 +++++++++++++++ .../builtin.ai-assistant/AiChatComponent.tsx | 61 ++- apps/desktop/src/api/providers/client.test.ts | 113 ++++ apps/desktop/src/api/providers/client.ts | 482 ++++++++++++------ apps/desktop/src/api/tauri.ts | 11 + 5 files changed, 900 insertions(+), 173 deletions(-) create mode 100644 apps/desktop/src/api/providers/client.test.ts diff --git a/apps/desktop/src-tauri/src/lib.rs b/apps/desktop/src-tauri/src/lib.rs index a63898e3..666833bf 100644 --- a/apps/desktop/src-tauri/src/lib.rs +++ b/apps/desktop/src-tauri/src/lib.rs @@ -2067,6 +2067,313 @@ async fn ollama_proxy_cancel( Ok(()) } +// ============================================================ +// Cloud streaming (SSE) +// ============================================================ + +/// Same shape as `OllamaStreams` but for the cloud-AI bridge. Each open +/// `cloud_proxy_stream` task registers a oneshot sender keyed by the +/// caller-supplied `streamId` so a follow-up `cloud_proxy_cancel` can +/// terminate it from the renderer. +type CloudStreams = Arc>>>; + +fn new_cloud_streams() -> CloudStreams { + Arc::new(Mutex::new(HashMap::new())) +} + +/// Payload for the `cloud-stream` event. Mirror of `OllamaStreamPayload`'s +/// camelCase wire form. Cloud providers each speak a slightly different +/// JSON shape inside the SSE `data:` lines, so we forward the raw payload +/// (joined by `\n` for multi-line events) and let the per-provider TS +/// adapter extract the delta. Keeping the Rust side provider-agnostic +/// keeps `cloud_proxy_stream` a thin transport that the visual editors +/// and future agent-mode loop can reuse without each gaining its own +/// protocol switch. +#[derive(Clone, Serialize)] +#[serde(rename_all = "camelCase")] +struct CloudStreamPayload { + stream_id: String, + kind: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +/// Splits an SSE stream buffer into complete events. Per the SSE spec, +/// events are separated by a blank line — `\n\n` or `\r\n\r\n`. Within an +/// event the `data:` lines are concatenated by `\n` (also per spec). We +/// drop everything else (`event:`, `id:`, `retry:`, comments) because none +/// of the four cloud providers in our allow-list use them to convey the +/// generation delta. +/// +/// Pure function so it can be unit tested without standing up a real HTTP +/// stream. Mirrors the Ollama side's `split_ndjson_lines` design. +fn split_sse_events(buffer: &mut String) -> Vec { + let mut events = Vec::new(); + loop { + let crlf = buffer.find("\r\n\r\n"); + let lf = buffer.find("\n\n"); + let (idx, sep_len) = match (crlf, lf) { + (Some(c), Some(l)) if c <= l => (c, 4), + (Some(c), None) => (c, 4), + (_, Some(l)) => (l, 2), + (None, None) => break, + }; + let event = buffer[..idx].to_string(); + buffer.drain(..idx + sep_len); + + let mut data_lines: Vec = Vec::new(); + for raw_line in event.split('\n') { + let line = raw_line.strip_suffix('\r').unwrap_or(raw_line); + if let Some(rest) = line.strip_prefix("data:") { + // Per SSE: a single space after the colon is part of the + // separator and must be stripped; further whitespace is + // payload. `strip_prefix(" ")` is `Some(rest)` exactly when + // a single leading space exists, so this gives the right + // semantics for both `data:` and `data: ` framings. + let payload = rest.strip_prefix(' ').unwrap_or(rest); + data_lines.push(payload.to_string()); + } + } + if !data_lines.is_empty() { + events.push(data_lines.join("\n")); + } + } + events +} + +/// Streaming sibling of `cloud_proxy`. Subjects every request to the same +/// per-host method + path + header allow-list, then pumps SSE events back +/// to the renderer through `cloud-stream` tagged with the caller-supplied +/// `streamId`. The Rust side is provider-agnostic — each `data:` payload +/// is forwarded verbatim and the TS adapter parses the OpenAI / Anthropic +/// / Gemini-specific JSON shape. +/// +/// Returns `Ok(())` as soon as the cancel handle is registered. The actual +/// transport runs on a detached tokio task so the bridge isn't held open +/// for the duration of a multi-second generation. Transport / size / +/// non-2xx failures are surfaced as `cloud-stream` `error` events. +#[tauri::command] +async fn cloud_proxy_stream( + app: tauri::AppHandle, + streams: tauri::State<'_, CloudStreams>, + stream_id: String, + method: String, + url: String, + headers: Option>, + body: Option, +) -> Result<(), String> { + use tauri::Emitter; + + let policy = validate_cloud_proxy_request(&url, &method)?; + + // Register the cancel handle BEFORE spawning so a fast follow-up + // `cloud_proxy_cancel` arriving between `spawn` and the first poll + // still has an entry to fire. + let (cancel_tx, mut cancel_rx) = tokio::sync::oneshot::channel::<()>(); + { + let mut guard = streams + .lock() + .map_err(|e| format!("cloud streams mutex poisoned: {}", e))?; + guard.insert(stream_id.clone(), cancel_tx); + } + + let app_for_task = app.clone(); + let streams_for_task = (*streams).clone(); + let stream_id_for_task = stream_id.clone(); + let allowed_headers = policy.allowed_headers; + + tauri::async_runtime::spawn(async move { + let client = http::shared_client(); + let mut request = match method.as_str() { + "POST" => client.post(&url), + "GET" => client.get(&url), + _ => unreachable!("method validated above"), + }; + request = request.timeout(std::time::Duration::from_secs(CLOUD_PROXY_TIMEOUT_SECS)); + + if let Some(pairs) = headers { + for (name, value) in pairs { + let lower = name.to_ascii_lowercase(); + if !allowed_headers.iter().any(|h| *h == lower) { + continue; + } + request = request.header(name, value); + } + } + if let Some(json_body) = body { + request = request.json(&json_body); + } + + let response = match request.send().await { + Ok(r) => r, + Err(e) => { + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "error", + data: None, + error: Some(format!("request failed: {}", e)), + }, + ); + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + return; + } + }; + + let status = response.status().as_u16(); + if !response.status().is_success() { + // Drain the body once and surface as a single error event so the + // frontend can show a useful message instead of "stream failed". + let body_text = response.text().await.unwrap_or_default(); + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "error", + data: None, + error: Some(format!("HTTP {}: {}", status, body_text)), + }, + ); + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + return; + } + + let mut response = response; + let mut buffer = String::new(); + let mut total_bytes: usize = 0; + let mut sent_done = false; + + loop { + if cancel_rx.try_recv().is_ok() + || matches!( + cancel_rx.try_recv(), + Err(tokio::sync::oneshot::error::TryRecvError::Closed) + ) + { + break; + } + + let chunk = match response.chunk().await { + Ok(Some(b)) => b, + Ok(None) => break, // EOF + Err(e) => { + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "error", + data: None, + error: Some(format!("chunk read failed: {}", e)), + }, + ); + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + return; + } + }; + + total_bytes = total_bytes.saturating_add(chunk.len()); + if total_bytes > CLOUD_PROXY_MAX_BODY { + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "error", + data: None, + error: Some(format!( + "Cloud stream exceeded the {} MiB cap", + CLOUD_PROXY_MAX_BODY / (1024 * 1024) + )), + }, + ); + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + return; + } + + buffer.push_str(&String::from_utf8_lossy(&chunk)); + let events = split_sse_events(&mut buffer); + for data in events { + // OpenAI / OpenRouter terminate with `data: [DONE]`. We + // collapse that to a structured `done` event so the TS + // side doesn't have to special-case the sentinel — and + // future providers that adopt the same convention work + // for free. + if data == "[DONE]" { + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "done", + data: None, + error: None, + }, + ); + sent_done = true; + break; + } + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "data", + data: Some(data), + error: None, + }, + ); + } + + if sent_done { + break; + } + } + + if !sent_done { + // Anthropic + Gemini close the connection on completion instead + // of sending a terminal sentinel. Synthesise a `done` event so + // the frontend's awaiter resolves either way. + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "done", + data: None, + error: None, + }, + ); + } + + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + }); + + Ok(()) +} + +fn remove_cloud_stream(streams: &CloudStreams, stream_id: &str) { + if let Ok(mut guard) = streams.lock() { + guard.remove(stream_id); + } +} + +/// Cancels an in-flight `cloud_proxy_stream` task. Idempotent. +#[tauri::command] +async fn cloud_proxy_cancel( + streams: tauri::State<'_, CloudStreams>, + stream_id: String, +) -> Result<(), String> { + let sender = { + let mut guard = streams + .lock() + .map_err(|e| format!("cloud streams mutex poisoned: {}", e))?; + guard.remove(&stream_id) + }; + if let Some(tx) = sender { + let _ = tx.send(()); + } + Ok(()) +} + #[tauri::command] fn create_directory( path: String, @@ -2232,6 +2539,7 @@ pub fn run() { ) .manage::(new_pty_sessions()) .manage::(new_ollama_streams()) + .manage::(new_cloud_streams()) .manage(Arc::new(Mutex::new(SystemState { sys: System::new_all(), }))) @@ -2293,6 +2601,8 @@ pub fn run() { ollama_proxy_stream, ollama_proxy_cancel, cloud_proxy, + cloud_proxy_stream, + cloud_proxy_cancel, create_directory, reveal_path, delete_path, @@ -2702,3 +3012,99 @@ mod split_ndjson_lines_tests { assert!(buf.is_empty()); } } + +#[cfg(test)] +mod split_sse_events_tests { + use super::split_sse_events; + + #[test] + fn parses_one_complete_event() { + let mut buf = String::from("data: {\"text\":\"hi\"}\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("{\"text\":\"hi\"}")]); + assert!( + buf.is_empty(), + "complete event should leave empty remainder" + ); + } + + #[test] + fn joins_multiple_data_lines_within_an_event_with_newline() { + let mut buf = String::from("data: line one\ndata: line two\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("line one\nline two")]); + } + + #[test] + fn skips_event_id_retry_and_comment_lines() { + let mut buf = + String::from("event: ping\nid: 42\nretry: 1000\n: this is a comment\ndata: hello\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("hello")]); + } + + #[test] + fn drops_events_without_a_data_line() { + let mut buf = String::from("event: ping\n\ndata: real\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("real")]); + } + + #[test] + fn holds_partial_event_in_buffer_until_next_chunk() { + let mut buf = String::new(); + buf.push_str("data: {\"a\":1}"); + let first = split_sse_events(&mut buf); + assert!(first.is_empty(), "no terminator yet"); + buf.push_str("\n\n"); + let second = split_sse_events(&mut buf); + assert_eq!(second, vec![String::from("{\"a\":1}")]); + } + + #[test] + fn handles_crlf_separators() { + let mut buf = String::from("data: foo\r\n\r\ndata: bar\r\n\r\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("foo"), String::from("bar")]); + assert!(buf.is_empty()); + } + + #[test] + fn passes_through_done_sentinel_unchanged() { + let mut buf = String::from("data: [DONE]\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("[DONE]")]); + } + + #[test] + fn empty_input_is_a_noop() { + let mut buf = String::new(); + let out = split_sse_events(&mut buf); + assert!(out.is_empty()); + assert!(buf.is_empty()); + } + + #[test] + fn handles_back_to_back_events_in_one_chunk() { + let mut buf = String::from("data: {\"a\":1}\n\ndata: {\"a\":2}\n\ndata: [DONE]\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!( + out, + vec![ + String::from("{\"a\":1}"), + String::from("{\"a\":2}"), + String::from("[DONE]"), + ] + ); + } + + #[test] + fn preserves_a_single_leading_space_after_data_colon() { + // Per SSE spec: a single space after `data:` is part of the + // separator, not the payload. Anything beyond that single space is + // payload — so `data: hello` carries ` hello` (one leading space). + let mut buf = String::from("data: hello\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from(" hello")]); + } +} diff --git a/apps/desktop/src/addons/builtin.ai-assistant/AiChatComponent.tsx b/apps/desktop/src/addons/builtin.ai-assistant/AiChatComponent.tsx index 6156c6d8..2a9a92e8 100644 --- a/apps/desktop/src/addons/builtin.ai-assistant/AiChatComponent.tsx +++ b/apps/desktop/src/addons/builtin.ai-assistant/AiChatComponent.tsx @@ -24,7 +24,7 @@ import { ToolApprovalPanel } from "./ToolApprovalPanel"; import { classifyToolError, formatToolError, failureKey } from "./toolErrors"; import { extractPlan } from "./planExtractor"; import { streamOllamaChat, type OllamaStreamFinalMessage } from "./ollamaStream"; -import { cloudChat, keyForProvider, type ChatMessage as ProviderChatMessage } from "@/api/providers/client"; +import { streamCloudChat, keyForProvider, type ChatMessage as ProviderChatMessage } from "@/api/providers/client"; import { PROVIDERS, PROVIDER_IDS } from "@/api/providers/registry"; type ToolArgs = Record; @@ -538,10 +538,10 @@ const AiChatComponent: React.FC = () => { abortControllerRef.current = controller; // Cloud-provider branch (issue #267). Routes the request through the - // generic `cloud_proxy` instead of Ollama's bridge. Tools / agent / - // streaming are intentionally not wired for cloud yet — every provider - // has a different SSE envelope and tool-call format. Single-shot - // chat works for all four providers we ship today. + // generic `cloud_proxy_stream` instead of Ollama's bridge. Tools / + // agent are not wired for cloud yet (each provider has its own + // tool-call shape — that lands in a follow-up); streaming text works + // for all four providers we ship today. const activeProvider = aiSettings.activeProvider ?? "ollama"; if (activeProvider !== "ollama") { try { @@ -583,26 +583,47 @@ const AiChatComponent: React.FC = () => { })), { role: "user", content: userMessage }, ]; - const result = await cloudChat({ - provider: activeProvider, - apiKey: cloudKey, - model: modelToUse, - messages: cloudHistory, - temperature: aiSettings.temperature, - maxTokens: aiSettings.maxTokens, - signal: controller.signal, - }); + + // Lazily create the assistant bubble on the first delta so we + // don't leave an empty placeholder behind if the request fails + // before any content arrives. Same pattern the Ollama branch + // uses below. + let placeholderPushed = false; + const pushPlaceholderOnce = () => { + if (placeholderPushed) return; + addMessageToSession(activeSessionId, { role: "ai", text: "" }); + placeholderPushed = true; + }; + + const result = await streamCloudChat( + { + provider: activeProvider, + apiKey: cloudKey, + model: modelToUse, + messages: cloudHistory, + temperature: aiSettings.temperature, + maxTokens: aiSettings.maxTokens, + signal: controller.signal, + }, + (delta) => { + pushPlaceholderOnce(); + appendToLastAiMessage(activeSessionId, delta); + }, + ); if (controller.signal.aborted) return; - if (!result.ok) { + if (!result.ok && !placeholderPushed) { addMessageToSession(activeSessionId, { role: "ai", text: `Error: ${result.error || "cloud provider returned no content"}`, }); - } else { - addMessageToSession(activeSessionId, { - role: "ai", - text: result.text, - }); + } else if (!result.ok) { + // Stream started, then failed mid-flight — append the error + // inline so the user can see how far the model got before the + // failure. The bubble already exists so we don't push a new one. + appendToLastAiMessage( + activeSessionId, + `\n\n[Error: ${result.error || "stream interrupted"}]`, + ); } } catch (err) { if (!controller.signal.aborted) { diff --git a/apps/desktop/src/api/providers/client.test.ts b/apps/desktop/src/api/providers/client.test.ts new file mode 100644 index 00000000..284082c8 --- /dev/null +++ b/apps/desktop/src/api/providers/client.test.ts @@ -0,0 +1,113 @@ +import { describe, expect, it } from "vitest"; +import { extractStreamDelta, extractFullResponse } from "./client"; + +describe("extractStreamDelta", () => { + it("returns empty string for the [DONE] sentinel and non-JSON input", () => { + for (const provider of ["openai", "anthropic", "gemini", "openrouter"] as const) { + expect(extractStreamDelta(provider, "[DONE]")).toBe(""); + expect(extractStreamDelta(provider, "")).toBe(""); + expect(extractStreamDelta(provider, "not-json")).toBe(""); + } + }); + + it("extracts choices[0].delta.content for OpenAI / OpenRouter", () => { + const data = JSON.stringify({ + choices: [{ delta: { content: "hello" } }], + }); + expect(extractStreamDelta("openai", data)).toBe("hello"); + expect(extractStreamDelta("openrouter", data)).toBe("hello"); + }); + + it("returns empty for OpenAI events without delta.content (role-only chunks)", () => { + const roleOnly = JSON.stringify({ + choices: [{ delta: { role: "assistant" } }], + }); + expect(extractStreamDelta("openai", roleOnly)).toBe(""); + }); + + it("extracts content_block_delta.text_delta for Anthropic", () => { + const data = JSON.stringify({ + type: "content_block_delta", + delta: { type: "text_delta", text: "world" }, + }); + expect(extractStreamDelta("anthropic", data)).toBe("world"); + }); + + it("ignores Anthropic housekeeping events (message_start, ping, message_stop)", () => { + expect( + extractStreamDelta("anthropic", JSON.stringify({ type: "ping" })), + ).toBe(""); + expect( + extractStreamDelta("anthropic", JSON.stringify({ type: "message_start" })), + ).toBe(""); + expect( + extractStreamDelta("anthropic", JSON.stringify({ type: "message_stop" })), + ).toBe(""); + }); + + it("ignores Anthropic content_block_delta with non-text_delta variants", () => { + // Future-proof: tool_use deltas show up under content_block_delta too + // and must NOT be appended as if they were chat text. + const toolUse = JSON.stringify({ + type: "content_block_delta", + delta: { type: "input_json_delta", partial_json: "{...}" }, + }); + expect(extractStreamDelta("anthropic", toolUse)).toBe(""); + }); + + it("extracts candidates[0].content.parts[*].text for Gemini", () => { + const data = JSON.stringify({ + candidates: [ + { content: { parts: [{ text: "foo" }, { text: " bar" }] } }, + ], + }); + expect(extractStreamDelta("gemini", data)).toBe("foo bar"); + }); + + it("returns empty for Gemini events without text parts", () => { + expect( + extractStreamDelta( + "gemini", + JSON.stringify({ candidates: [{ finishReason: "STOP" }] }), + ), + ).toBe(""); + }); +}); + +describe("extractFullResponse", () => { + it("extracts choices[0].message.content for OpenAI / OpenRouter", () => { + const body = JSON.stringify({ + choices: [{ message: { content: "complete reply" } }], + }); + expect(extractFullResponse("openai", body)).toBe("complete reply"); + expect(extractFullResponse("openrouter", body)).toBe("complete reply"); + }); + + it("joins all content[*].text blocks for Anthropic, ignoring tool_use", () => { + const body = JSON.stringify({ + content: [ + { type: "text", text: "alpha " }, + { type: "tool_use", id: "x", name: "y", input: {} }, + { type: "text", text: "beta" }, + ], + }); + expect(extractFullResponse("anthropic", body)).toBe("alpha beta"); + }); + + it("joins all candidates[*].content.parts[*].text blocks for Gemini", () => { + const body = JSON.stringify({ + candidates: [ + { content: { parts: [{ text: "one " }, { text: "two" }] } }, + { content: { parts: [{ text: " three" }] } }, + ], + }); + expect(extractFullResponse("gemini", body)).toBe("one two three"); + }); + + it("returns empty string for non-JSON or empty bodies", () => { + for (const provider of ["openai", "anthropic", "gemini", "openrouter"] as const) { + expect(extractFullResponse(provider, "")).toBe(""); + expect(extractFullResponse(provider, "garbage")).toBe(""); + } + }); +}); diff --git a/apps/desktop/src/api/providers/client.ts b/apps/desktop/src/api/providers/client.ts index d53aeb31..ffc9f3a7 100644 --- a/apps/desktop/src/api/providers/client.ts +++ b/apps/desktop/src/api/providers/client.ts @@ -1,5 +1,7 @@ "use client"; +import { listen, type UnlistenFn } from "@tauri-apps/api/event"; +import { invoke as tauriInvoke } from "@tauri-apps/api/core"; import { safeInvoke as invoke } from "@/api/tauri"; import type { ProviderId, ProviderKeys } from "@/context/SettingsContext"; import { logger } from "@/lib/logger"; @@ -25,14 +27,125 @@ export interface CloudChatResult { error?: string; } +interface ProviderRequest { + url: string; + headers: Array<[string, string]>; + body: unknown; +} + /** - * Single-shot non-streaming chat for cloud providers. Streaming for - * cloud is a follow-up — every provider here uses a different SSE - * envelope and the Rust side currently only has a streaming bridge for - * Ollama. Returning the whole reply once the request resolves lets us - * wire all four providers without a per-provider streaming bridge. + * Build the URL / headers / body for a provider chat call. The `stream` + * flag toggles the streaming endpoint (Gemini changes URL; the others + * just set `stream: true` in the body) so the same builder backs both + * `cloudChat` and `streamCloudChat`. */ -export async function cloudChat(req: CloudChatRequest): Promise { +function buildProviderRequest( + req: CloudChatRequest, + stream: boolean, +): ProviderRequest { + switch (req.provider) { + case "openai": + return { + url: "https://api.openai.com/v1/chat/completions", + headers: [ + ["Authorization", `Bearer ${req.apiKey}`], + ["Content-Type", "application/json"], + ], + body: { + model: req.model, + messages: req.messages, + temperature: req.temperature ?? 0.7, + max_tokens: req.maxTokens ?? 2048, + stream, + }, + }; + case "openrouter": + return { + url: "https://openrouter.ai/api/v1/chat/completions", + headers: [ + ["Authorization", `Bearer ${req.apiKey}`], + ["Content-Type", "application/json"], + ["HTTP-Referer", "https://github.com/TrixtyAI/ide"], + ["X-Title", "Trixty IDE"], + ], + body: { + model: req.model, + messages: req.messages, + temperature: req.temperature ?? 0.7, + max_tokens: req.maxTokens ?? 2048, + stream, + }, + }; + case "anthropic": { + // Anthropic separates the system prompt from the messages array. + const systemMessages = req.messages.filter((m) => m.role === "system"); + const conversation = req.messages.filter((m) => m.role !== "system"); + const system = systemMessages + .map((m) => m.content) + .join("\n\n") + .trim(); + return { + url: "https://api.anthropic.com/v1/messages", + headers: [ + ["x-api-key", req.apiKey], + ["anthropic-version", "2023-06-01"], + ["Content-Type", "application/json"], + ], + body: { + model: req.model, + max_tokens: req.maxTokens ?? 2048, + temperature: req.temperature ?? 0.7, + system: system || undefined, + messages: conversation.map((m) => ({ + role: m.role === "assistant" ? "assistant" : "user", + content: m.content, + })), + stream, + }, + }; + } + case "gemini": { + // Gemini supports the API key either as a query param or the + // `x-goog-api-key` header. We use the header form so the key never + // shows up in URL logs (Tauri's `e.to_string()` on a transport + // failure echoes the URL, OS-level proxies log query strings, etc.). + const path = stream ? "streamGenerateContent?alt=sse" : "generateContent"; + const url = `https://generativelanguage.googleapis.com/v1beta/models/${encodeURIComponent( + req.model, + )}:${path}`; + const systemMessages = req.messages.filter((m) => m.role === "system"); + const conversation = req.messages.filter((m) => m.role !== "system"); + const systemInstruction = systemMessages.length + ? { + role: "user", + parts: [ + { text: systemMessages.map((m) => m.content).join("\n\n") }, + ], + } + : undefined; + return { + url, + headers: [ + ["x-goog-api-key", req.apiKey], + ["Content-Type", "application/json"], + ], + body: { + contents: conversation.map((m) => ({ + role: m.role === "assistant" ? "model" : "user", + parts: [{ text: m.content }], + })), + ...(systemInstruction ? { systemInstruction } : {}), + generationConfig: { + temperature: req.temperature ?? 0.7, + maxOutputTokens: req.maxTokens ?? 2048, + }, + }, + }; + } + } +} + +function validateRequest(req: CloudChatRequest): CloudChatResult | null { if (req.signal?.aborted) return { ok: false, text: "", error: "aborted" }; if (!req.apiKey) { return { @@ -48,24 +161,40 @@ export async function cloudChat(req: CloudChatRequest): Promise error: `No model selected for ${req.provider}`, }; } + return null; +} + +/** + * Single-shot non-streaming chat for cloud providers. Kept for callers + * that need the full reply atomically (e.g. background tool-summary + * jobs). For interactive chat use `streamCloudChat`, which threads + * tokens into the UI as they arrive. + */ +export async function cloudChat(req: CloudChatRequest): Promise { + const invalid = validateRequest(req); + if (invalid) return invalid; try { - switch (req.provider) { - case "openai": - return await chatOpenAICompatible( - "https://api.openai.com/v1/chat/completions", - req, - ); - case "openrouter": - return await chatOpenAICompatible( - "https://openrouter.ai/api/v1/chat/completions", - req, - ); - case "anthropic": - return await chatAnthropic(req); - case "gemini": - return await chatGemini(req); + const config = buildProviderRequest(req, false); + const result = await invoke( + "cloud_proxy", + { + method: "POST", + url: config.url, + headers: config.headers, + body: config.body, + }, + { silent: true }, + ); + if (result.status < 200 || result.status >= 300) { + return { + ok: false, + text: "", + error: `${req.provider} HTTP ${result.status}: ${truncate(result.body, 240)}`, + }; } + const text = extractFullResponse(req.provider, result.body); + return { ok: text.length > 0, text }; } catch (err) { if (req.signal?.aborted) return { ok: false, text: "", error: "aborted" }; logger.warn(`[providers/${req.provider}] chat failed:`, err); @@ -74,153 +203,200 @@ export async function cloudChat(req: CloudChatRequest): Promise } /** - * OpenAI / OpenRouter share the `/v1/chat/completions` shape. OpenRouter - * also requires the standard `Authorization: Bearer KEY` header — the - * `HTTP-Referer` and `X-Title` headers are optional metadata that - * surface the app name in OpenRouter's dashboards. + * Streaming chat for cloud providers. Emits each token to `onDelta` as + * it arrives and resolves with the full concatenated text on completion. + * The Rust `cloud_proxy_stream` command pumps SSE events back through + * the `cloud-stream` Tauri event keyed by a UUID `streamId`. + * + * Cancellation: if `req.signal` aborts mid-stream the helper fires + * `cloud_proxy_cancel` so the tokio task tears down before more chunks + * arrive, then re-throws an `AbortError` so callers can branch the same + * way they do for `streamOllamaChat`. */ -async function chatOpenAICompatible( - url: string, +export async function streamCloudChat( req: CloudChatRequest, + onDelta: (text: string) => void, ): Promise { - const headers: Array<[string, string]> = [ - ["Authorization", `Bearer ${req.apiKey}`], - ["Content-Type", "application/json"], - ]; - if (req.provider === "openrouter") { - headers.push(["HTTP-Referer", "https://github.com/TrixtyAI/ide"]); - headers.push(["X-Title", "Trixty IDE"]); + const invalid = validateRequest(req); + if (invalid) return invalid; + + const config = buildProviderRequest(req, true); + const streamId = crypto.randomUUID(); + + let fullText = ""; + let errorText: string | undefined; + let unlisten: UnlistenFn | undefined; + let aborted = false; + + const settled = new Promise((resolve, reject) => { + listen("cloud-stream", (event) => { + const payload = event.payload; + if (payload.streamId !== streamId) return; + if (payload.kind === "data") { + const delta = extractStreamDelta(req.provider, payload.data ?? ""); + if (delta) { + fullText += delta; + onDelta(delta); + } + return; + } + if (payload.kind === "done") { + resolve(); + return; + } + if (payload.kind === "error") { + errorText = payload.error ?? "Unknown streaming error"; + reject(new Error(errorText)); + } + }).then((u) => { + unlisten = u; + }); + }); + + const onAbort = () => { + aborted = true; + tauriInvoke("cloud_proxy_cancel", { streamId }).catch((err) => { + logger.debug("[providers/cloud] cancel failed:", err); + }); + }; + if (req.signal?.aborted) { + onAbort(); + } else if (req.signal) { + req.signal.addEventListener("abort", onAbort, { once: true }); } - const result = await invoke( - "cloud_proxy", - { + + try { + await tauriInvoke("cloud_proxy_stream", { + streamId, method: "POST", - url, - headers, - body: { - model: req.model, - messages: req.messages, - temperature: req.temperature ?? 0.7, - max_tokens: req.maxTokens ?? 2048, - }, - }, - { silent: true }, - ); - if (result.status < 200 || result.status >= 300) { + url: config.url, + headers: config.headers, + body: config.body, + }); + await settled; + return { ok: fullText.length > 0, text: fullText }; + } catch (err) { + if (aborted) { + const abortError = new Error("Aborted"); + abortError.name = "AbortError"; + throw abortError; + } return { ok: false, - text: "", - error: `${req.provider} HTTP ${result.status}: ${truncate(result.body, 240)}`, + text: fullText, + error: errorText ?? (err instanceof Error ? err.message : String(err)), }; + } finally { + if (unlisten) unlisten(); + if (req.signal) req.signal.removeEventListener("abort", onAbort); } - const parsed = JSON.parse(result.body) as { - choices?: { message?: { content?: string } }[]; - }; - const text = parsed.choices?.[0]?.message?.content ?? ""; - return { ok: text.length > 0, text }; } -async function chatAnthropic(req: CloudChatRequest): Promise { - // Anthropic separates the system prompt from the messages array. Pull - // any leading system message into the dedicated `system` field. - const systemMessages = req.messages.filter((m) => m.role === "system"); - const conversation = req.messages.filter((m) => m.role !== "system"); - const system = systemMessages.map((m) => m.content).join("\n\n").trim(); - const result = await invoke( - "cloud_proxy", - { - method: "POST", - url: "https://api.anthropic.com/v1/messages", - headers: [ - ["x-api-key", req.apiKey], - ["anthropic-version", "2023-06-01"], - ["Content-Type", "application/json"], - ], - body: { - model: req.model, - max_tokens: req.maxTokens ?? 2048, - temperature: req.temperature ?? 0.7, - system: system || undefined, - messages: conversation.map((m) => ({ - role: m.role === "assistant" ? "assistant" : "user", - content: m.content, - })), - }, - }, - { silent: true }, - ); - if (result.status < 200 || result.status >= 300) { - return { - ok: false, - text: "", - error: `anthropic HTTP ${result.status}: ${truncate(result.body, 240)}`, - }; +interface CloudStreamEvent { + streamId: string; + kind: "data" | "done" | "error"; + data?: string; + error?: string; +} + +/** + * Per-provider parse of one SSE `data:` payload into the delta text to + * append. Returns `""` for keep-alive / housekeeping events (Anthropic + * `ping`, message_start / stop, etc.) so the caller can ignore them. + */ +export function extractStreamDelta( + provider: Exclude, + raw: string, +): string { + if (!raw || raw === "[DONE]") return ""; + let parsed: unknown; + try { + parsed = JSON.parse(raw); + } catch { + return ""; + } + if (typeof parsed !== "object" || parsed === null) return ""; + const obj = parsed as Record; + + switch (provider) { + case "openai": + case "openrouter": { + const choices = obj.choices as + | Array<{ delta?: { content?: string } }> + | undefined; + return choices?.[0]?.delta?.content ?? ""; + } + case "anthropic": { + if (obj.type !== "content_block_delta") return ""; + const delta = obj.delta as + | { type?: string; text?: string } + | undefined; + if (delta?.type !== "text_delta") return ""; + return delta.text ?? ""; + } + case "gemini": { + const candidates = obj.candidates as + | Array<{ content?: { parts?: Array<{ text?: string }> } }> + | undefined; + return ( + candidates + ?.flatMap((c) => c.content?.parts ?? []) + .map((p) => p.text ?? "") + .join("") ?? "" + ); + } } - const parsed = JSON.parse(result.body) as { - content?: { type: string; text?: string }[]; - }; - const text = (parsed.content ?? []) - .filter((b) => b.type === "text") - .map((b) => b.text ?? "") - .join(""); - return { ok: text.length > 0, text }; } -async function chatGemini(req: CloudChatRequest): Promise { - // Gemini supports the API key either as a query param or the - // `x-goog-api-key` header. We use the header form so the key never - // shows up in URL logs (Tauri's `e.to_string()` on a transport - // failure echoes the URL, OS-level proxies log query strings, etc.). - // Body uses Gemini's `contents` shape with `user` / `model` roles. - const url = `https://generativelanguage.googleapis.com/v1beta/models/${encodeURIComponent( - req.model, - )}:generateContent`; - const systemMessages = req.messages.filter((m) => m.role === "system"); - const conversation = req.messages.filter((m) => m.role !== "system"); - const systemInstruction = systemMessages.length - ? { - role: "user", - parts: [{ text: systemMessages.map((m) => m.content).join("\n\n") }], - } - : undefined; - const result = await invoke( - "cloud_proxy", - { - method: "POST", - url, - headers: [ - ["x-goog-api-key", req.apiKey], - ["Content-Type", "application/json"], - ], - body: { - contents: conversation.map((m) => ({ - role: m.role === "assistant" ? "model" : "user", - parts: [{ text: m.content }], - })), - ...(systemInstruction ? { systemInstruction } : {}), - generationConfig: { - temperature: req.temperature ?? 0.7, - maxOutputTokens: req.maxTokens ?? 2048, - }, - }, - }, - { silent: true }, - ); - if (result.status < 200 || result.status >= 300) { - return { - ok: false, - text: "", - error: `gemini HTTP ${result.status}: ${truncate(result.body, 240)}`, - }; +/** + * Per-provider parse of a non-streaming JSON body. Mirror of + * `extractStreamDelta`'s switch but for the full-response shape. Kept + * exported so unit tests can hit each branch without the proxy bridge. + */ +export function extractFullResponse( + provider: Exclude, + body: string, +): string { + let parsed: unknown; + try { + parsed = JSON.parse(body); + } catch { + return ""; + } + if (typeof parsed !== "object" || parsed === null) return ""; + const obj = parsed as Record; + + switch (provider) { + case "openai": + case "openrouter": { + const choices = obj.choices as + | Array<{ message?: { content?: string } }> + | undefined; + return choices?.[0]?.message?.content ?? ""; + } + case "anthropic": { + const content = obj.content as + | Array<{ type: string; text?: string }> + | undefined; + return ( + content + ?.filter((b) => b.type === "text") + .map((b) => b.text ?? "") + .join("") ?? "" + ); + } + case "gemini": { + const candidates = obj.candidates as + | Array<{ content?: { parts?: Array<{ text?: string }> } }> + | undefined; + return ( + candidates + ?.flatMap((c) => c.content?.parts ?? []) + .map((p) => p.text ?? "") + .join("") ?? "" + ); + } } - const parsed = JSON.parse(result.body) as { - candidates?: { content?: { parts?: { text?: string }[] } }[]; - }; - const text = (parsed.candidates ?? []) - .flatMap((c) => c.content?.parts ?? []) - .map((p) => p.text ?? "") - .join(""); - return { ok: text.length > 0, text }; } function truncate(s: string, n: number): string { diff --git a/apps/desktop/src/api/tauri.ts b/apps/desktop/src/api/tauri.ts index f35ad30f..f88d5133 100644 --- a/apps/desktop/src/api/tauri.ts +++ b/apps/desktop/src/api/tauri.ts @@ -107,6 +107,17 @@ export interface TauriInvokeMap { }; return: { status: number; body: string }; }; + "cloud_proxy_stream": { + args: { + streamId: string; + method: string; + url: string; + headers?: Array<[string, string]>; + body?: unknown; + }; + return: void; + }; + "cloud_proxy_cancel": { args: { streamId: string }; return: void }; "check_update": { args: { channel?: "stable" | "pre-release" }; return: { version: string; body?: string | null } | null }; "install_update": { args: { channel?: "stable" | "pre-release" }; return: void }; "spawn_pty": { args: { sessionId: string; cwd?: string; rows?: number; cols?: number }; return: void };