diff --git a/apps/desktop/src-tauri/src/lib.rs b/apps/desktop/src-tauri/src/lib.rs index a63898e3..666833bf 100644 --- a/apps/desktop/src-tauri/src/lib.rs +++ b/apps/desktop/src-tauri/src/lib.rs @@ -2067,6 +2067,313 @@ async fn ollama_proxy_cancel( Ok(()) } +// ============================================================ +// Cloud streaming (SSE) +// ============================================================ + +/// Same shape as `OllamaStreams` but for the cloud-AI bridge. Each open +/// `cloud_proxy_stream` task registers a oneshot sender keyed by the +/// caller-supplied `streamId` so a follow-up `cloud_proxy_cancel` can +/// terminate it from the renderer. +type CloudStreams = Arc>>>; + +fn new_cloud_streams() -> CloudStreams { + Arc::new(Mutex::new(HashMap::new())) +} + +/// Payload for the `cloud-stream` event. Mirror of `OllamaStreamPayload`'s +/// camelCase wire form. Cloud providers each speak a slightly different +/// JSON shape inside the SSE `data:` lines, so we forward the raw payload +/// (joined by `\n` for multi-line events) and let the per-provider TS +/// adapter extract the delta. Keeping the Rust side provider-agnostic +/// keeps `cloud_proxy_stream` a thin transport that the visual editors +/// and future agent-mode loop can reuse without each gaining its own +/// protocol switch. +#[derive(Clone, Serialize)] +#[serde(rename_all = "camelCase")] +struct CloudStreamPayload { + stream_id: String, + kind: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +/// Splits an SSE stream buffer into complete events. Per the SSE spec, +/// events are separated by a blank line — `\n\n` or `\r\n\r\n`. Within an +/// event the `data:` lines are concatenated by `\n` (also per spec). We +/// drop everything else (`event:`, `id:`, `retry:`, comments) because none +/// of the four cloud providers in our allow-list use them to convey the +/// generation delta. +/// +/// Pure function so it can be unit tested without standing up a real HTTP +/// stream. Mirrors the Ollama side's `split_ndjson_lines` design. +fn split_sse_events(buffer: &mut String) -> Vec { + let mut events = Vec::new(); + loop { + let crlf = buffer.find("\r\n\r\n"); + let lf = buffer.find("\n\n"); + let (idx, sep_len) = match (crlf, lf) { + (Some(c), Some(l)) if c <= l => (c, 4), + (Some(c), None) => (c, 4), + (_, Some(l)) => (l, 2), + (None, None) => break, + }; + let event = buffer[..idx].to_string(); + buffer.drain(..idx + sep_len); + + let mut data_lines: Vec = Vec::new(); + for raw_line in event.split('\n') { + let line = raw_line.strip_suffix('\r').unwrap_or(raw_line); + if let Some(rest) = line.strip_prefix("data:") { + // Per SSE: a single space after the colon is part of the + // separator and must be stripped; further whitespace is + // payload. `strip_prefix(" ")` is `Some(rest)` exactly when + // a single leading space exists, so this gives the right + // semantics for both `data:` and `data: ` framings. + let payload = rest.strip_prefix(' ').unwrap_or(rest); + data_lines.push(payload.to_string()); + } + } + if !data_lines.is_empty() { + events.push(data_lines.join("\n")); + } + } + events +} + +/// Streaming sibling of `cloud_proxy`. Subjects every request to the same +/// per-host method + path + header allow-list, then pumps SSE events back +/// to the renderer through `cloud-stream` tagged with the caller-supplied +/// `streamId`. The Rust side is provider-agnostic — each `data:` payload +/// is forwarded verbatim and the TS adapter parses the OpenAI / Anthropic +/// / Gemini-specific JSON shape. +/// +/// Returns `Ok(())` as soon as the cancel handle is registered. The actual +/// transport runs on a detached tokio task so the bridge isn't held open +/// for the duration of a multi-second generation. Transport / size / +/// non-2xx failures are surfaced as `cloud-stream` `error` events. +#[tauri::command] +async fn cloud_proxy_stream( + app: tauri::AppHandle, + streams: tauri::State<'_, CloudStreams>, + stream_id: String, + method: String, + url: String, + headers: Option>, + body: Option, +) -> Result<(), String> { + use tauri::Emitter; + + let policy = validate_cloud_proxy_request(&url, &method)?; + + // Register the cancel handle BEFORE spawning so a fast follow-up + // `cloud_proxy_cancel` arriving between `spawn` and the first poll + // still has an entry to fire. + let (cancel_tx, mut cancel_rx) = tokio::sync::oneshot::channel::<()>(); + { + let mut guard = streams + .lock() + .map_err(|e| format!("cloud streams mutex poisoned: {}", e))?; + guard.insert(stream_id.clone(), cancel_tx); + } + + let app_for_task = app.clone(); + let streams_for_task = (*streams).clone(); + let stream_id_for_task = stream_id.clone(); + let allowed_headers = policy.allowed_headers; + + tauri::async_runtime::spawn(async move { + let client = http::shared_client(); + let mut request = match method.as_str() { + "POST" => client.post(&url), + "GET" => client.get(&url), + _ => unreachable!("method validated above"), + }; + request = request.timeout(std::time::Duration::from_secs(CLOUD_PROXY_TIMEOUT_SECS)); + + if let Some(pairs) = headers { + for (name, value) in pairs { + let lower = name.to_ascii_lowercase(); + if !allowed_headers.iter().any(|h| *h == lower) { + continue; + } + request = request.header(name, value); + } + } + if let Some(json_body) = body { + request = request.json(&json_body); + } + + let response = match request.send().await { + Ok(r) => r, + Err(e) => { + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "error", + data: None, + error: Some(format!("request failed: {}", e)), + }, + ); + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + return; + } + }; + + let status = response.status().as_u16(); + if !response.status().is_success() { + // Drain the body once and surface as a single error event so the + // frontend can show a useful message instead of "stream failed". + let body_text = response.text().await.unwrap_or_default(); + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "error", + data: None, + error: Some(format!("HTTP {}: {}", status, body_text)), + }, + ); + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + return; + } + + let mut response = response; + let mut buffer = String::new(); + let mut total_bytes: usize = 0; + let mut sent_done = false; + + loop { + if cancel_rx.try_recv().is_ok() + || matches!( + cancel_rx.try_recv(), + Err(tokio::sync::oneshot::error::TryRecvError::Closed) + ) + { + break; + } + + let chunk = match response.chunk().await { + Ok(Some(b)) => b, + Ok(None) => break, // EOF + Err(e) => { + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "error", + data: None, + error: Some(format!("chunk read failed: {}", e)), + }, + ); + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + return; + } + }; + + total_bytes = total_bytes.saturating_add(chunk.len()); + if total_bytes > CLOUD_PROXY_MAX_BODY { + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "error", + data: None, + error: Some(format!( + "Cloud stream exceeded the {} MiB cap", + CLOUD_PROXY_MAX_BODY / (1024 * 1024) + )), + }, + ); + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + return; + } + + buffer.push_str(&String::from_utf8_lossy(&chunk)); + let events = split_sse_events(&mut buffer); + for data in events { + // OpenAI / OpenRouter terminate with `data: [DONE]`. We + // collapse that to a structured `done` event so the TS + // side doesn't have to special-case the sentinel — and + // future providers that adopt the same convention work + // for free. + if data == "[DONE]" { + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "done", + data: None, + error: None, + }, + ); + sent_done = true; + break; + } + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "data", + data: Some(data), + error: None, + }, + ); + } + + if sent_done { + break; + } + } + + if !sent_done { + // Anthropic + Gemini close the connection on completion instead + // of sending a terminal sentinel. Synthesise a `done` event so + // the frontend's awaiter resolves either way. + let _ = app_for_task.emit( + "cloud-stream", + CloudStreamPayload { + stream_id: stream_id_for_task.clone(), + kind: "done", + data: None, + error: None, + }, + ); + } + + remove_cloud_stream(&streams_for_task, &stream_id_for_task); + }); + + Ok(()) +} + +fn remove_cloud_stream(streams: &CloudStreams, stream_id: &str) { + if let Ok(mut guard) = streams.lock() { + guard.remove(stream_id); + } +} + +/// Cancels an in-flight `cloud_proxy_stream` task. Idempotent. +#[tauri::command] +async fn cloud_proxy_cancel( + streams: tauri::State<'_, CloudStreams>, + stream_id: String, +) -> Result<(), String> { + let sender = { + let mut guard = streams + .lock() + .map_err(|e| format!("cloud streams mutex poisoned: {}", e))?; + guard.remove(&stream_id) + }; + if let Some(tx) = sender { + let _ = tx.send(()); + } + Ok(()) +} + #[tauri::command] fn create_directory( path: String, @@ -2232,6 +2539,7 @@ pub fn run() { ) .manage::(new_pty_sessions()) .manage::(new_ollama_streams()) + .manage::(new_cloud_streams()) .manage(Arc::new(Mutex::new(SystemState { sys: System::new_all(), }))) @@ -2293,6 +2601,8 @@ pub fn run() { ollama_proxy_stream, ollama_proxy_cancel, cloud_proxy, + cloud_proxy_stream, + cloud_proxy_cancel, create_directory, reveal_path, delete_path, @@ -2702,3 +3012,99 @@ mod split_ndjson_lines_tests { assert!(buf.is_empty()); } } + +#[cfg(test)] +mod split_sse_events_tests { + use super::split_sse_events; + + #[test] + fn parses_one_complete_event() { + let mut buf = String::from("data: {\"text\":\"hi\"}\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("{\"text\":\"hi\"}")]); + assert!( + buf.is_empty(), + "complete event should leave empty remainder" + ); + } + + #[test] + fn joins_multiple_data_lines_within_an_event_with_newline() { + let mut buf = String::from("data: line one\ndata: line two\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("line one\nline two")]); + } + + #[test] + fn skips_event_id_retry_and_comment_lines() { + let mut buf = + String::from("event: ping\nid: 42\nretry: 1000\n: this is a comment\ndata: hello\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("hello")]); + } + + #[test] + fn drops_events_without_a_data_line() { + let mut buf = String::from("event: ping\n\ndata: real\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("real")]); + } + + #[test] + fn holds_partial_event_in_buffer_until_next_chunk() { + let mut buf = String::new(); + buf.push_str("data: {\"a\":1}"); + let first = split_sse_events(&mut buf); + assert!(first.is_empty(), "no terminator yet"); + buf.push_str("\n\n"); + let second = split_sse_events(&mut buf); + assert_eq!(second, vec![String::from("{\"a\":1}")]); + } + + #[test] + fn handles_crlf_separators() { + let mut buf = String::from("data: foo\r\n\r\ndata: bar\r\n\r\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("foo"), String::from("bar")]); + assert!(buf.is_empty()); + } + + #[test] + fn passes_through_done_sentinel_unchanged() { + let mut buf = String::from("data: [DONE]\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from("[DONE]")]); + } + + #[test] + fn empty_input_is_a_noop() { + let mut buf = String::new(); + let out = split_sse_events(&mut buf); + assert!(out.is_empty()); + assert!(buf.is_empty()); + } + + #[test] + fn handles_back_to_back_events_in_one_chunk() { + let mut buf = String::from("data: {\"a\":1}\n\ndata: {\"a\":2}\n\ndata: [DONE]\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!( + out, + vec![ + String::from("{\"a\":1}"), + String::from("{\"a\":2}"), + String::from("[DONE]"), + ] + ); + } + + #[test] + fn preserves_a_single_leading_space_after_data_colon() { + // Per SSE spec: a single space after `data:` is part of the + // separator, not the payload. Anything beyond that single space is + // payload — so `data: hello` carries ` hello` (one leading space). + let mut buf = String::from("data: hello\n\n"); + let out = split_sse_events(&mut buf); + assert_eq!(out, vec![String::from(" hello")]); + } +} diff --git a/apps/desktop/src/addons/builtin.ai-assistant/AiChatComponent.tsx b/apps/desktop/src/addons/builtin.ai-assistant/AiChatComponent.tsx index 6156c6d8..2a9a92e8 100644 --- a/apps/desktop/src/addons/builtin.ai-assistant/AiChatComponent.tsx +++ b/apps/desktop/src/addons/builtin.ai-assistant/AiChatComponent.tsx @@ -24,7 +24,7 @@ import { ToolApprovalPanel } from "./ToolApprovalPanel"; import { classifyToolError, formatToolError, failureKey } from "./toolErrors"; import { extractPlan } from "./planExtractor"; import { streamOllamaChat, type OllamaStreamFinalMessage } from "./ollamaStream"; -import { cloudChat, keyForProvider, type ChatMessage as ProviderChatMessage } from "@/api/providers/client"; +import { streamCloudChat, keyForProvider, type ChatMessage as ProviderChatMessage } from "@/api/providers/client"; import { PROVIDERS, PROVIDER_IDS } from "@/api/providers/registry"; type ToolArgs = Record; @@ -538,10 +538,10 @@ const AiChatComponent: React.FC = () => { abortControllerRef.current = controller; // Cloud-provider branch (issue #267). Routes the request through the - // generic `cloud_proxy` instead of Ollama's bridge. Tools / agent / - // streaming are intentionally not wired for cloud yet — every provider - // has a different SSE envelope and tool-call format. Single-shot - // chat works for all four providers we ship today. + // generic `cloud_proxy_stream` instead of Ollama's bridge. Tools / + // agent are not wired for cloud yet (each provider has its own + // tool-call shape — that lands in a follow-up); streaming text works + // for all four providers we ship today. const activeProvider = aiSettings.activeProvider ?? "ollama"; if (activeProvider !== "ollama") { try { @@ -583,26 +583,47 @@ const AiChatComponent: React.FC = () => { })), { role: "user", content: userMessage }, ]; - const result = await cloudChat({ - provider: activeProvider, - apiKey: cloudKey, - model: modelToUse, - messages: cloudHistory, - temperature: aiSettings.temperature, - maxTokens: aiSettings.maxTokens, - signal: controller.signal, - }); + + // Lazily create the assistant bubble on the first delta so we + // don't leave an empty placeholder behind if the request fails + // before any content arrives. Same pattern the Ollama branch + // uses below. + let placeholderPushed = false; + const pushPlaceholderOnce = () => { + if (placeholderPushed) return; + addMessageToSession(activeSessionId, { role: "ai", text: "" }); + placeholderPushed = true; + }; + + const result = await streamCloudChat( + { + provider: activeProvider, + apiKey: cloudKey, + model: modelToUse, + messages: cloudHistory, + temperature: aiSettings.temperature, + maxTokens: aiSettings.maxTokens, + signal: controller.signal, + }, + (delta) => { + pushPlaceholderOnce(); + appendToLastAiMessage(activeSessionId, delta); + }, + ); if (controller.signal.aborted) return; - if (!result.ok) { + if (!result.ok && !placeholderPushed) { addMessageToSession(activeSessionId, { role: "ai", text: `Error: ${result.error || "cloud provider returned no content"}`, }); - } else { - addMessageToSession(activeSessionId, { - role: "ai", - text: result.text, - }); + } else if (!result.ok) { + // Stream started, then failed mid-flight — append the error + // inline so the user can see how far the model got before the + // failure. The bubble already exists so we don't push a new one. + appendToLastAiMessage( + activeSessionId, + `\n\n[Error: ${result.error || "stream interrupted"}]`, + ); } } catch (err) { if (!controller.signal.aborted) { diff --git a/apps/desktop/src/api/providers/client.test.ts b/apps/desktop/src/api/providers/client.test.ts new file mode 100644 index 00000000..284082c8 --- /dev/null +++ b/apps/desktop/src/api/providers/client.test.ts @@ -0,0 +1,113 @@ +import { describe, expect, it } from "vitest"; +import { extractStreamDelta, extractFullResponse } from "./client"; + +describe("extractStreamDelta", () => { + it("returns empty string for the [DONE] sentinel and non-JSON input", () => { + for (const provider of ["openai", "anthropic", "gemini", "openrouter"] as const) { + expect(extractStreamDelta(provider, "[DONE]")).toBe(""); + expect(extractStreamDelta(provider, "")).toBe(""); + expect(extractStreamDelta(provider, "not-json")).toBe(""); + } + }); + + it("extracts choices[0].delta.content for OpenAI / OpenRouter", () => { + const data = JSON.stringify({ + choices: [{ delta: { content: "hello" } }], + }); + expect(extractStreamDelta("openai", data)).toBe("hello"); + expect(extractStreamDelta("openrouter", data)).toBe("hello"); + }); + + it("returns empty for OpenAI events without delta.content (role-only chunks)", () => { + const roleOnly = JSON.stringify({ + choices: [{ delta: { role: "assistant" } }], + }); + expect(extractStreamDelta("openai", roleOnly)).toBe(""); + }); + + it("extracts content_block_delta.text_delta for Anthropic", () => { + const data = JSON.stringify({ + type: "content_block_delta", + delta: { type: "text_delta", text: "world" }, + }); + expect(extractStreamDelta("anthropic", data)).toBe("world"); + }); + + it("ignores Anthropic housekeeping events (message_start, ping, message_stop)", () => { + expect( + extractStreamDelta("anthropic", JSON.stringify({ type: "ping" })), + ).toBe(""); + expect( + extractStreamDelta("anthropic", JSON.stringify({ type: "message_start" })), + ).toBe(""); + expect( + extractStreamDelta("anthropic", JSON.stringify({ type: "message_stop" })), + ).toBe(""); + }); + + it("ignores Anthropic content_block_delta with non-text_delta variants", () => { + // Future-proof: tool_use deltas show up under content_block_delta too + // and must NOT be appended as if they were chat text. + const toolUse = JSON.stringify({ + type: "content_block_delta", + delta: { type: "input_json_delta", partial_json: "{...}" }, + }); + expect(extractStreamDelta("anthropic", toolUse)).toBe(""); + }); + + it("extracts candidates[0].content.parts[*].text for Gemini", () => { + const data = JSON.stringify({ + candidates: [ + { content: { parts: [{ text: "foo" }, { text: " bar" }] } }, + ], + }); + expect(extractStreamDelta("gemini", data)).toBe("foo bar"); + }); + + it("returns empty for Gemini events without text parts", () => { + expect( + extractStreamDelta( + "gemini", + JSON.stringify({ candidates: [{ finishReason: "STOP" }] }), + ), + ).toBe(""); + }); +}); + +describe("extractFullResponse", () => { + it("extracts choices[0].message.content for OpenAI / OpenRouter", () => { + const body = JSON.stringify({ + choices: [{ message: { content: "complete reply" } }], + }); + expect(extractFullResponse("openai", body)).toBe("complete reply"); + expect(extractFullResponse("openrouter", body)).toBe("complete reply"); + }); + + it("joins all content[*].text blocks for Anthropic, ignoring tool_use", () => { + const body = JSON.stringify({ + content: [ + { type: "text", text: "alpha " }, + { type: "tool_use", id: "x", name: "y", input: {} }, + { type: "text", text: "beta" }, + ], + }); + expect(extractFullResponse("anthropic", body)).toBe("alpha beta"); + }); + + it("joins all candidates[*].content.parts[*].text blocks for Gemini", () => { + const body = JSON.stringify({ + candidates: [ + { content: { parts: [{ text: "one " }, { text: "two" }] } }, + { content: { parts: [{ text: " three" }] } }, + ], + }); + expect(extractFullResponse("gemini", body)).toBe("one two three"); + }); + + it("returns empty string for non-JSON or empty bodies", () => { + for (const provider of ["openai", "anthropic", "gemini", "openrouter"] as const) { + expect(extractFullResponse(provider, "")).toBe(""); + expect(extractFullResponse(provider, "garbage")).toBe(""); + } + }); +}); diff --git a/apps/desktop/src/api/providers/client.ts b/apps/desktop/src/api/providers/client.ts index d53aeb31..ffc9f3a7 100644 --- a/apps/desktop/src/api/providers/client.ts +++ b/apps/desktop/src/api/providers/client.ts @@ -1,5 +1,7 @@ "use client"; +import { listen, type UnlistenFn } from "@tauri-apps/api/event"; +import { invoke as tauriInvoke } from "@tauri-apps/api/core"; import { safeInvoke as invoke } from "@/api/tauri"; import type { ProviderId, ProviderKeys } from "@/context/SettingsContext"; import { logger } from "@/lib/logger"; @@ -25,14 +27,125 @@ export interface CloudChatResult { error?: string; } +interface ProviderRequest { + url: string; + headers: Array<[string, string]>; + body: unknown; +} + /** - * Single-shot non-streaming chat for cloud providers. Streaming for - * cloud is a follow-up — every provider here uses a different SSE - * envelope and the Rust side currently only has a streaming bridge for - * Ollama. Returning the whole reply once the request resolves lets us - * wire all four providers without a per-provider streaming bridge. + * Build the URL / headers / body for a provider chat call. The `stream` + * flag toggles the streaming endpoint (Gemini changes URL; the others + * just set `stream: true` in the body) so the same builder backs both + * `cloudChat` and `streamCloudChat`. */ -export async function cloudChat(req: CloudChatRequest): Promise { +function buildProviderRequest( + req: CloudChatRequest, + stream: boolean, +): ProviderRequest { + switch (req.provider) { + case "openai": + return { + url: "https://api.openai.com/v1/chat/completions", + headers: [ + ["Authorization", `Bearer ${req.apiKey}`], + ["Content-Type", "application/json"], + ], + body: { + model: req.model, + messages: req.messages, + temperature: req.temperature ?? 0.7, + max_tokens: req.maxTokens ?? 2048, + stream, + }, + }; + case "openrouter": + return { + url: "https://openrouter.ai/api/v1/chat/completions", + headers: [ + ["Authorization", `Bearer ${req.apiKey}`], + ["Content-Type", "application/json"], + ["HTTP-Referer", "https://github.com/TrixtyAI/ide"], + ["X-Title", "Trixty IDE"], + ], + body: { + model: req.model, + messages: req.messages, + temperature: req.temperature ?? 0.7, + max_tokens: req.maxTokens ?? 2048, + stream, + }, + }; + case "anthropic": { + // Anthropic separates the system prompt from the messages array. + const systemMessages = req.messages.filter((m) => m.role === "system"); + const conversation = req.messages.filter((m) => m.role !== "system"); + const system = systemMessages + .map((m) => m.content) + .join("\n\n") + .trim(); + return { + url: "https://api.anthropic.com/v1/messages", + headers: [ + ["x-api-key", req.apiKey], + ["anthropic-version", "2023-06-01"], + ["Content-Type", "application/json"], + ], + body: { + model: req.model, + max_tokens: req.maxTokens ?? 2048, + temperature: req.temperature ?? 0.7, + system: system || undefined, + messages: conversation.map((m) => ({ + role: m.role === "assistant" ? "assistant" : "user", + content: m.content, + })), + stream, + }, + }; + } + case "gemini": { + // Gemini supports the API key either as a query param or the + // `x-goog-api-key` header. We use the header form so the key never + // shows up in URL logs (Tauri's `e.to_string()` on a transport + // failure echoes the URL, OS-level proxies log query strings, etc.). + const path = stream ? "streamGenerateContent?alt=sse" : "generateContent"; + const url = `https://generativelanguage.googleapis.com/v1beta/models/${encodeURIComponent( + req.model, + )}:${path}`; + const systemMessages = req.messages.filter((m) => m.role === "system"); + const conversation = req.messages.filter((m) => m.role !== "system"); + const systemInstruction = systemMessages.length + ? { + role: "user", + parts: [ + { text: systemMessages.map((m) => m.content).join("\n\n") }, + ], + } + : undefined; + return { + url, + headers: [ + ["x-goog-api-key", req.apiKey], + ["Content-Type", "application/json"], + ], + body: { + contents: conversation.map((m) => ({ + role: m.role === "assistant" ? "model" : "user", + parts: [{ text: m.content }], + })), + ...(systemInstruction ? { systemInstruction } : {}), + generationConfig: { + temperature: req.temperature ?? 0.7, + maxOutputTokens: req.maxTokens ?? 2048, + }, + }, + }; + } + } +} + +function validateRequest(req: CloudChatRequest): CloudChatResult | null { if (req.signal?.aborted) return { ok: false, text: "", error: "aborted" }; if (!req.apiKey) { return { @@ -48,24 +161,40 @@ export async function cloudChat(req: CloudChatRequest): Promise error: `No model selected for ${req.provider}`, }; } + return null; +} + +/** + * Single-shot non-streaming chat for cloud providers. Kept for callers + * that need the full reply atomically (e.g. background tool-summary + * jobs). For interactive chat use `streamCloudChat`, which threads + * tokens into the UI as they arrive. + */ +export async function cloudChat(req: CloudChatRequest): Promise { + const invalid = validateRequest(req); + if (invalid) return invalid; try { - switch (req.provider) { - case "openai": - return await chatOpenAICompatible( - "https://api.openai.com/v1/chat/completions", - req, - ); - case "openrouter": - return await chatOpenAICompatible( - "https://openrouter.ai/api/v1/chat/completions", - req, - ); - case "anthropic": - return await chatAnthropic(req); - case "gemini": - return await chatGemini(req); + const config = buildProviderRequest(req, false); + const result = await invoke( + "cloud_proxy", + { + method: "POST", + url: config.url, + headers: config.headers, + body: config.body, + }, + { silent: true }, + ); + if (result.status < 200 || result.status >= 300) { + return { + ok: false, + text: "", + error: `${req.provider} HTTP ${result.status}: ${truncate(result.body, 240)}`, + }; } + const text = extractFullResponse(req.provider, result.body); + return { ok: text.length > 0, text }; } catch (err) { if (req.signal?.aborted) return { ok: false, text: "", error: "aborted" }; logger.warn(`[providers/${req.provider}] chat failed:`, err); @@ -74,153 +203,200 @@ export async function cloudChat(req: CloudChatRequest): Promise } /** - * OpenAI / OpenRouter share the `/v1/chat/completions` shape. OpenRouter - * also requires the standard `Authorization: Bearer KEY` header — the - * `HTTP-Referer` and `X-Title` headers are optional metadata that - * surface the app name in OpenRouter's dashboards. + * Streaming chat for cloud providers. Emits each token to `onDelta` as + * it arrives and resolves with the full concatenated text on completion. + * The Rust `cloud_proxy_stream` command pumps SSE events back through + * the `cloud-stream` Tauri event keyed by a UUID `streamId`. + * + * Cancellation: if `req.signal` aborts mid-stream the helper fires + * `cloud_proxy_cancel` so the tokio task tears down before more chunks + * arrive, then re-throws an `AbortError` so callers can branch the same + * way they do for `streamOllamaChat`. */ -async function chatOpenAICompatible( - url: string, +export async function streamCloudChat( req: CloudChatRequest, + onDelta: (text: string) => void, ): Promise { - const headers: Array<[string, string]> = [ - ["Authorization", `Bearer ${req.apiKey}`], - ["Content-Type", "application/json"], - ]; - if (req.provider === "openrouter") { - headers.push(["HTTP-Referer", "https://github.com/TrixtyAI/ide"]); - headers.push(["X-Title", "Trixty IDE"]); + const invalid = validateRequest(req); + if (invalid) return invalid; + + const config = buildProviderRequest(req, true); + const streamId = crypto.randomUUID(); + + let fullText = ""; + let errorText: string | undefined; + let unlisten: UnlistenFn | undefined; + let aborted = false; + + const settled = new Promise((resolve, reject) => { + listen("cloud-stream", (event) => { + const payload = event.payload; + if (payload.streamId !== streamId) return; + if (payload.kind === "data") { + const delta = extractStreamDelta(req.provider, payload.data ?? ""); + if (delta) { + fullText += delta; + onDelta(delta); + } + return; + } + if (payload.kind === "done") { + resolve(); + return; + } + if (payload.kind === "error") { + errorText = payload.error ?? "Unknown streaming error"; + reject(new Error(errorText)); + } + }).then((u) => { + unlisten = u; + }); + }); + + const onAbort = () => { + aborted = true; + tauriInvoke("cloud_proxy_cancel", { streamId }).catch((err) => { + logger.debug("[providers/cloud] cancel failed:", err); + }); + }; + if (req.signal?.aborted) { + onAbort(); + } else if (req.signal) { + req.signal.addEventListener("abort", onAbort, { once: true }); } - const result = await invoke( - "cloud_proxy", - { + + try { + await tauriInvoke("cloud_proxy_stream", { + streamId, method: "POST", - url, - headers, - body: { - model: req.model, - messages: req.messages, - temperature: req.temperature ?? 0.7, - max_tokens: req.maxTokens ?? 2048, - }, - }, - { silent: true }, - ); - if (result.status < 200 || result.status >= 300) { + url: config.url, + headers: config.headers, + body: config.body, + }); + await settled; + return { ok: fullText.length > 0, text: fullText }; + } catch (err) { + if (aborted) { + const abortError = new Error("Aborted"); + abortError.name = "AbortError"; + throw abortError; + } return { ok: false, - text: "", - error: `${req.provider} HTTP ${result.status}: ${truncate(result.body, 240)}`, + text: fullText, + error: errorText ?? (err instanceof Error ? err.message : String(err)), }; + } finally { + if (unlisten) unlisten(); + if (req.signal) req.signal.removeEventListener("abort", onAbort); } - const parsed = JSON.parse(result.body) as { - choices?: { message?: { content?: string } }[]; - }; - const text = parsed.choices?.[0]?.message?.content ?? ""; - return { ok: text.length > 0, text }; } -async function chatAnthropic(req: CloudChatRequest): Promise { - // Anthropic separates the system prompt from the messages array. Pull - // any leading system message into the dedicated `system` field. - const systemMessages = req.messages.filter((m) => m.role === "system"); - const conversation = req.messages.filter((m) => m.role !== "system"); - const system = systemMessages.map((m) => m.content).join("\n\n").trim(); - const result = await invoke( - "cloud_proxy", - { - method: "POST", - url: "https://api.anthropic.com/v1/messages", - headers: [ - ["x-api-key", req.apiKey], - ["anthropic-version", "2023-06-01"], - ["Content-Type", "application/json"], - ], - body: { - model: req.model, - max_tokens: req.maxTokens ?? 2048, - temperature: req.temperature ?? 0.7, - system: system || undefined, - messages: conversation.map((m) => ({ - role: m.role === "assistant" ? "assistant" : "user", - content: m.content, - })), - }, - }, - { silent: true }, - ); - if (result.status < 200 || result.status >= 300) { - return { - ok: false, - text: "", - error: `anthropic HTTP ${result.status}: ${truncate(result.body, 240)}`, - }; +interface CloudStreamEvent { + streamId: string; + kind: "data" | "done" | "error"; + data?: string; + error?: string; +} + +/** + * Per-provider parse of one SSE `data:` payload into the delta text to + * append. Returns `""` for keep-alive / housekeeping events (Anthropic + * `ping`, message_start / stop, etc.) so the caller can ignore them. + */ +export function extractStreamDelta( + provider: Exclude, + raw: string, +): string { + if (!raw || raw === "[DONE]") return ""; + let parsed: unknown; + try { + parsed = JSON.parse(raw); + } catch { + return ""; + } + if (typeof parsed !== "object" || parsed === null) return ""; + const obj = parsed as Record; + + switch (provider) { + case "openai": + case "openrouter": { + const choices = obj.choices as + | Array<{ delta?: { content?: string } }> + | undefined; + return choices?.[0]?.delta?.content ?? ""; + } + case "anthropic": { + if (obj.type !== "content_block_delta") return ""; + const delta = obj.delta as + | { type?: string; text?: string } + | undefined; + if (delta?.type !== "text_delta") return ""; + return delta.text ?? ""; + } + case "gemini": { + const candidates = obj.candidates as + | Array<{ content?: { parts?: Array<{ text?: string }> } }> + | undefined; + return ( + candidates + ?.flatMap((c) => c.content?.parts ?? []) + .map((p) => p.text ?? "") + .join("") ?? "" + ); + } } - const parsed = JSON.parse(result.body) as { - content?: { type: string; text?: string }[]; - }; - const text = (parsed.content ?? []) - .filter((b) => b.type === "text") - .map((b) => b.text ?? "") - .join(""); - return { ok: text.length > 0, text }; } -async function chatGemini(req: CloudChatRequest): Promise { - // Gemini supports the API key either as a query param or the - // `x-goog-api-key` header. We use the header form so the key never - // shows up in URL logs (Tauri's `e.to_string()` on a transport - // failure echoes the URL, OS-level proxies log query strings, etc.). - // Body uses Gemini's `contents` shape with `user` / `model` roles. - const url = `https://generativelanguage.googleapis.com/v1beta/models/${encodeURIComponent( - req.model, - )}:generateContent`; - const systemMessages = req.messages.filter((m) => m.role === "system"); - const conversation = req.messages.filter((m) => m.role !== "system"); - const systemInstruction = systemMessages.length - ? { - role: "user", - parts: [{ text: systemMessages.map((m) => m.content).join("\n\n") }], - } - : undefined; - const result = await invoke( - "cloud_proxy", - { - method: "POST", - url, - headers: [ - ["x-goog-api-key", req.apiKey], - ["Content-Type", "application/json"], - ], - body: { - contents: conversation.map((m) => ({ - role: m.role === "assistant" ? "model" : "user", - parts: [{ text: m.content }], - })), - ...(systemInstruction ? { systemInstruction } : {}), - generationConfig: { - temperature: req.temperature ?? 0.7, - maxOutputTokens: req.maxTokens ?? 2048, - }, - }, - }, - { silent: true }, - ); - if (result.status < 200 || result.status >= 300) { - return { - ok: false, - text: "", - error: `gemini HTTP ${result.status}: ${truncate(result.body, 240)}`, - }; +/** + * Per-provider parse of a non-streaming JSON body. Mirror of + * `extractStreamDelta`'s switch but for the full-response shape. Kept + * exported so unit tests can hit each branch without the proxy bridge. + */ +export function extractFullResponse( + provider: Exclude, + body: string, +): string { + let parsed: unknown; + try { + parsed = JSON.parse(body); + } catch { + return ""; + } + if (typeof parsed !== "object" || parsed === null) return ""; + const obj = parsed as Record; + + switch (provider) { + case "openai": + case "openrouter": { + const choices = obj.choices as + | Array<{ message?: { content?: string } }> + | undefined; + return choices?.[0]?.message?.content ?? ""; + } + case "anthropic": { + const content = obj.content as + | Array<{ type: string; text?: string }> + | undefined; + return ( + content + ?.filter((b) => b.type === "text") + .map((b) => b.text ?? "") + .join("") ?? "" + ); + } + case "gemini": { + const candidates = obj.candidates as + | Array<{ content?: { parts?: Array<{ text?: string }> } }> + | undefined; + return ( + candidates + ?.flatMap((c) => c.content?.parts ?? []) + .map((p) => p.text ?? "") + .join("") ?? "" + ); + } } - const parsed = JSON.parse(result.body) as { - candidates?: { content?: { parts?: { text?: string }[] } }[]; - }; - const text = (parsed.candidates ?? []) - .flatMap((c) => c.content?.parts ?? []) - .map((p) => p.text ?? "") - .join(""); - return { ok: text.length > 0, text }; } function truncate(s: string, n: number): string { diff --git a/apps/desktop/src/api/tauri.ts b/apps/desktop/src/api/tauri.ts index f35ad30f..f88d5133 100644 --- a/apps/desktop/src/api/tauri.ts +++ b/apps/desktop/src/api/tauri.ts @@ -107,6 +107,17 @@ export interface TauriInvokeMap { }; return: { status: number; body: string }; }; + "cloud_proxy_stream": { + args: { + streamId: string; + method: string; + url: string; + headers?: Array<[string, string]>; + body?: unknown; + }; + return: void; + }; + "cloud_proxy_cancel": { args: { streamId: string }; return: void }; "check_update": { args: { channel?: "stable" | "pre-release" }; return: { version: string; body?: string | null } | null }; "install_update": { args: { channel?: "stable" | "pre-release" }; return: void }; "spawn_pty": { args: { sessionId: string; cwd?: string; rows?: number; cols?: number }; return: void };