From 39b5387dc24f2fb38122254d4e582540cec7ed49 Mon Sep 17 00:00:00 2001 From: Jordan Ritter Date: Thu, 23 Apr 2026 16:22:57 -0700 Subject: [PATCH 1/6] fix: server core hardening and type system updates readBody 10MB size limit, control API error detail, one-shot race fix, normalizeCompatPath clarity, fixtures_loaded gauge updates on mutations, VideoStateMap TTL/bounds. JournalEntry source field, FixtureResponse type widening, matchesPattern JSDoc, isContentWithToolCallsResponse re-export. --- src/__tests__/chaos-fixture-mode.test.ts | 161 ---- src/__tests__/control-api.test.ts | 5 +- src/__tests__/metrics.test.ts | 44 +- src/__tests__/server.test.ts | 11 - src/chaos.ts | 46 +- src/helpers.ts | 9 + src/index.ts | 10 +- src/journal.ts | 3 +- src/server.ts | 1082 +++++++++++----------- src/types.ts | 17 +- 10 files changed, 569 insertions(+), 819 deletions(-) delete mode 100644 src/__tests__/chaos-fixture-mode.test.ts diff --git a/src/__tests__/chaos-fixture-mode.test.ts b/src/__tests__/chaos-fixture-mode.test.ts deleted file mode 100644 index 0411c26..0000000 --- a/src/__tests__/chaos-fixture-mode.test.ts +++ /dev/null @@ -1,161 +0,0 @@ -import { describe, it, expect, afterEach } from "vitest"; -import { createServer, type ServerInstance } from "../server.js"; - -// minimal helpers duplicated to keep this test isolated -import * as http from "node:http"; - -function post(url: string, body: unknown): Promise<{ status: number; body: string }> { - return new Promise((resolve, reject) => { - const data = JSON.stringify(body); - const parsed = new URL(url); - const req = http.request( - { - hostname: parsed.hostname, - port: parsed.port, - path: parsed.pathname, - method: "POST", - headers: { - "Content-Type": "application/json", - "Content-Length": Buffer.byteLength(data), - }, - }, - (res) => { - const chunks: Buffer[] = []; - res.on("data", (c: Buffer) => chunks.push(c)); - res.on("end", () => { - resolve({ status: res.statusCode ?? 
0, body: Buffer.concat(chunks).toString() }); - }); - }, - ); - req.on("error", reject); - req.write(data); - req.end(); - }); -} - -function get(url: string): Promise<{ status: number; body: string }> { - return new Promise((resolve, reject) => { - const parsed = new URL(url); - const req = http.request( - { - hostname: parsed.hostname, - port: parsed.port, - path: parsed.pathname + parsed.search, - method: "GET", - }, - (res) => { - const chunks: Buffer[] = []; - res.on("data", (c: Buffer) => chunks.push(c)); - res.on("end", () => { - resolve({ status: res.statusCode ?? 0, body: Buffer.concat(chunks).toString() }); - }); - }, - ); - req.on("error", reject); - req.end(); - }); -} - -let server: ServerInstance | undefined; - -afterEach(async () => { - if (server) { - await new Promise((resolve) => server!.server.close(() => resolve())); - server = undefined; - } -}); - -const CHAT_REQUEST = { - model: "gpt-4", - messages: [{ role: "user", content: "What is the capital of France?" }], -}; - -describe("chaos (fixture mode)", () => { - it("chaos short-circuits even when fixture would match", async () => { - const fixture = { - match: { userMessage: "capital of France" }, - response: { content: "Paris" }, - }; - - server = await createServer([fixture], { - port: 0, - chaos: { dropRate: 1.0 }, - }); - - const resp = await post(`${server.url}/v1/chat/completions`, CHAT_REQUEST); - - expect(resp.status).toBe(500); - const body = JSON.parse(resp.body); - expect(body).toMatchObject({ error: { code: "chaos_drop" } }); - }); - - it("rolls chaos once per request: drop journals the matched fixture, not null", async () => { - // Pins the single-roll behavior: chaos evaluation happens AFTER fixture - // matching, so when drop fires on a request that matches a fixture, the - // journal entry reflects the match (not null, as the old double-roll - // pre-flight path would have recorded). 
- const fixture = { - match: { userMessage: "capital of France" }, - response: { content: "Paris" }, - }; - - server = await createServer([fixture], { - port: 0, - chaos: { dropRate: 1.0 }, - }); - - const resp = await post(`${server.url}/v1/chat/completions`, CHAT_REQUEST); - expect(resp.status).toBe(500); - - const last = server.journal.getLast(); - expect(last?.response.chaosAction).toBe("drop"); - expect(last?.response.fixture).toBe(fixture); - // Match count reflects that the fixture did participate in the decision - expect(server.journal.getFixtureMatchCount(fixture)).toBe(1); - }); - - it("disconnect journals the matched fixture with status 0", async () => { - // Symmetric to the drop test above. Disconnect's status is 0 (no response - // ever written before res.destroy()) which is a slightly unusual shape; - // pin it so future refactors don't silently change it to e.g. 500. - const fixture = { - match: { userMessage: "capital of France" }, - response: { content: "Paris" }, - }; - - server = await createServer([fixture], { - port: 0, - chaos: { disconnectRate: 1.0 }, - }); - - // Client sees a socket destroy mid-request → post() rejects - await expect(post(`${server.url}/v1/chat/completions`, CHAT_REQUEST)).rejects.toThrow(); - - const last = server.journal.getLast(); - expect(last?.response.chaosAction).toBe("disconnect"); - expect(last?.response.status).toBe(0); - expect(last?.response.fixture).toBe(fixture); - expect(server.journal.getFixtureMatchCount(fixture)).toBe(1); - }); - - it("handleVideoStatus: chaos drop fires before video-not-found 404", async () => { - // Without any video state stored, a normal GET /v1/videos/ would - // return 404. With dropRate: 1.0 chaos should fire first, returning the - // 500 chaos_drop response instead. 
- server = await createServer([], { - port: 0, - chaos: { dropRate: 1.0 }, - }); - - const resp = await get(`${server.url}/v1/videos/test-video-id`); - - expect(resp.status).toBe(500); - const body = JSON.parse(resp.body); - expect(body).toMatchObject({ error: { code: "chaos_drop" } }); - - // Journal records the chaos action, not the 404 - const last = server.journal.getLast(); - expect(last?.response.chaosAction).toBe("drop"); - expect(last?.response.status).toBe(500); - }); -}); diff --git a/src/__tests__/control-api.test.ts b/src/__tests__/control-api.test.ts index 50a502a..12e91f4 100644 --- a/src/__tests__/control-api.test.ts +++ b/src/__tests__/control-api.test.ts @@ -142,7 +142,7 @@ describe("/__aimock control API", () => { const res = await httpRaw(`${instance.url}/__aimock/fixtures`, "POST", "not json{{{"); expect(res.status).toBe(400); const body = JSON.parse(res.body); - expect(body.error).toBe("Invalid JSON"); + expect(body.error).toMatch(/^Invalid JSON:/); }); it("returns 400 when fixtures array is missing", async () => { @@ -256,9 +256,6 @@ describe("/__aimock control API", () => { ); expect(errRes.status).toBe(429); - // Wait for queueMicrotask to clean up the one-shot fixture - await new Promise((r) => setTimeout(r, 50)); - // Second request should succeed normally const okRes = await httpRequest( `${instance.url}/v1/chat/completions`, diff --git a/src/__tests__/metrics.test.ts b/src/__tests__/metrics.test.ts index 52c366d..9cd9434 100644 --- a/src/__tests__/metrics.test.ts +++ b/src/__tests__/metrics.test.ts @@ -549,7 +549,7 @@ describe("integration: /metrics endpoint", () => { expect(infMatch![1]).toBe(countMatch![1]); }); - it("increments chaos counter when chaos triggers (fixture source)", async () => { + it("increments chaos counter when chaos triggers", async () => { const fixtures: Fixture[] = [ { match: { userMessage: "hello" }, @@ -565,43 +565,7 @@ describe("integration: /metrics endpoint", () => { const res = await 
httpGet(`${instance.url}/metrics`); expect(res.body).toContain("aimock_chaos_triggered_total"); - // Require both labels: action AND source. The source label is part of the - // public metric contract (added when chaos was extended to proxy mode) and - // an unasserted label is a regression hazard — future callers that forget - // to pass source would serialize `source=""` and pass a bare action match. - expect(res.body).toMatch( - /aimock_chaos_triggered_total\{[^}]*action="drop"[^}]*source="fixture"[^}]*\} 1/, - ); - }); - - it('chaos counter carries source="proxy" on proxy path', async () => { - // Counterpart to the fixture-source test: proves the source label flips - // correctly when the chaos roll belongs to the proxy dispatch branch. - // Together these two tests pin both label values of the source dimension. - const upstream = await createServer( - [{ match: { userMessage: "hi" }, response: { content: "upstream" } }], - { port: 0 }, - ); - try { - instance = await createServer([], { - metrics: true, - chaos: { dropRate: 1.0 }, - record: { - providers: { openai: upstream.url }, - fixturePath: "/tmp/aimock-metrics-proxy-source", - proxyOnly: true, - }, - }); - - await httpPost(`${instance.url}/v1/chat/completions`, chatRequest("hi")); - - const res = await httpGet(`${instance.url}/metrics`); - expect(res.body).toMatch( - /aimock_chaos_triggered_total\{[^}]*action="drop"[^}]*source="proxy"[^}]*\} 1/, - ); - } finally { - await new Promise((resolve) => upstream.server.close(() => resolve())); - } + expect(res.body).toMatch(/aimock_chaos_triggered_total\{[^}]*action="drop"[^}]*\} 1/); }); it("increments chaos counter on Anthropic /v1/messages endpoint", async () => { @@ -624,9 +588,7 @@ describe("integration: /metrics endpoint", () => { const res = await httpGet(`${instance.url}/metrics`); expect(res.body).toContain("aimock_chaos_triggered_total"); - expect(res.body).toMatch( - /aimock_chaos_triggered_total\{[^}]*action="drop"[^}]*source="fixture"[^}]*\} 1/, - 
); + expect(res.body).toMatch(/aimock_chaos_triggered_total\{[^}]*action="drop"[^}]*\} 1/); }); it("tracks fixtures loaded gauge", async () => { diff --git a/src/__tests__/server.test.ts b/src/__tests__/server.test.ts index 1581081..eae99ed 100644 --- a/src/__tests__/server.test.ts +++ b/src/__tests__/server.test.ts @@ -609,17 +609,6 @@ describe("CORS", () => { expect(res.headers["access-control-allow-methods"]).toContain("POST"); }); - it("OPTIONS preflight includes chaos control headers", async () => { - instance = await createServer(allFixtures); - const res = await options(`${instance.url}/v1/chat/completions`); - - const allowHeaders = res.headers["access-control-allow-headers"] ?? ""; - expect(allowHeaders).toContain("X-Aimock-Chaos-Drop"); - expect(allowHeaders).toContain("X-Aimock-Chaos-Malformed"); - expect(allowHeaders).toContain("X-Aimock-Chaos-Disconnect"); - expect(allowHeaders).toContain("X-Test-Id"); - }); - it("includes CORS headers on 404 responses", async () => { instance = await createServer(allFixtures); const res = await post(`${instance.url}/v1/chat/completions`, { diff --git a/src/chaos.ts b/src/chaos.ts index a2b8be0..f30b927 100644 --- a/src/chaos.ts +++ b/src/chaos.ts @@ -131,16 +131,12 @@ interface ChaosJournalContext { method: string; path: string; headers: Record; - body: ChatCompletionRequest | null; + body: ChatCompletionRequest; } /** * Apply chaos to a request. Returns true if chaos was applied (caller should * return early), false if the request should proceed normally. - * - * `source` is required so the invariant "this handler only applies chaos in - * the phase" is enforced at the type level. A future handler that grows - * a proxy path MUST pass `"proxy"` explicitly; the default can't drift silently. 
*/ export function applyChaos( res: http.ServerResponse, @@ -149,45 +145,21 @@ export function applyChaos( rawHeaders: http.IncomingHttpHeaders, journal: Journal, context: ChaosJournalContext, - source: "fixture" | "proxy", registry?: MetricsRegistry, logger?: Logger, ): boolean { const action = evaluateChaos(fixture, serverDefaults, rawHeaders, logger); if (!action) return false; - applyChaosAction(action, res, fixture, journal, context, source, registry); - return true; -} -/** - * Apply a specific (already-rolled) chaos action. Exposed so callers that roll - * the dice themselves can dispatch without re-rolling — important when the - * caller wants to branch on the action before committing (e.g. pre-flight vs. - * post-response phases). - * - * `source` is required (not optional) so callers can't silently omit it on - * one branch and journal an ambiguous entry. Pass `"fixture"` when a fixture - * matched (or would have) and `"proxy"` when the request was headed for the - * proxy path. 
- */ -export function applyChaosAction( - action: ChaosAction, - res: http.ServerResponse, - fixture: Fixture | null, - journal: Journal, - context: ChaosJournalContext, - source: "fixture" | "proxy", - registry?: MetricsRegistry, -): void { if (registry) { - registry.incrementCounter("aimock_chaos_triggered_total", { action, source }); + registry.incrementCounter("aimock_chaos_triggered_total", { action }); } switch (action) { case "drop": { journal.add({ ...context, - response: { status: 500, fixture, chaosAction: "drop", source }, + response: { status: 500, fixture, chaosAction: "drop" }, }); writeErrorResponse( res, @@ -200,29 +172,29 @@ export function applyChaosAction( }, }), ); - return; + return true; } case "malformed": { journal.add({ ...context, - response: { status: 200, fixture, chaosAction: "malformed", source }, + response: { status: 200, fixture, chaosAction: "malformed" }, }); res.writeHead(200, { "Content-Type": "application/json" }); res.end("{malformed json: <<>>"); - return; + return true; } case "disconnect": { journal.add({ ...context, - response: { status: 0, fixture, chaosAction: "disconnect", source }, + response: { status: 0, fixture, chaosAction: "disconnect" }, }); res.destroy(); - return; + return true; } default: { const _exhaustive: never = action; void _exhaustive; - return; + return false; } } } diff --git a/src/helpers.ts b/src/helpers.ts index 799902f..2b654a6 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -554,6 +554,15 @@ export function readBody(req: http.IncomingMessage): Promise { // ─── Pattern matching ───────────────────────────────────────────────────── +/** + * Case-insensitive substring/regex match used for search, rerank, and + * moderation endpoints where exact casing rarely matters. String patterns + * are lowercased on both sides before comparison. 
+ * + * Note: This intentionally differs from the case-sensitive matching in + * {@link matchFixture} (router.ts), where fixture authors expect exact + * string matching against chat completion user messages. + */ export function matchesPattern(text: string, pattern: string | RegExp): boolean { if (typeof pattern === "string") { return text.toLowerCase().includes(pattern.toLowerCase()); diff --git a/src/index.ts b/src/index.ts index e69602d..046908e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -24,7 +24,12 @@ export { Journal, DEFAULT_TEST_ID } from "./journal.js"; export { matchFixture, getTextContent } from "./router.js"; // Provider handlers -export { handleResponses, buildTextStreamEvents, buildToolCallStreamEvents } from "./responses.js"; +export { + handleResponses, + buildTextStreamEvents, + buildToolCallStreamEvents, + buildContentWithToolCallsStreamEvents, +} from "./responses.js"; export type { ResponsesSSEEvent } from "./responses.js"; export { handleMessages } from "./messages.js"; export { handleGemini } from "./gemini.js"; @@ -78,8 +83,7 @@ export { handleWebSocketGeminiLive } from "./ws-gemini-live.js"; export { handleImages } from "./images.js"; export { handleSpeech } from "./speech.js"; export { handleTranscription } from "./transcription.js"; -export { handleVideoCreate, handleVideoStatus } from "./video.js"; -export type { VideoStateMap } from "./video.js"; +export { handleVideoCreate, handleVideoStatus, VideoStateMap } from "./video.js"; // Helpers export { diff --git a/src/journal.ts b/src/journal.ts index 9b0611f..b00fbd6 100644 --- a/src/journal.ts +++ b/src/journal.ts @@ -25,7 +25,8 @@ function matchCriteriaEqual(a: FixtureMatch, b: FixtureMatch): boolean { fieldEqual(a.toolName, b.toolName) && fieldEqual(a.model, b.model) && fieldEqual(a.responseFormat, b.responseFormat) && - fieldEqual(a.predicate, b.predicate) + fieldEqual(a.predicate, b.predicate) && + fieldEqual(a.endpoint, b.endpoint) ); } diff --git a/src/server.ts b/src/server.ts 
index 16b0ec3..08144b5 100644 --- a/src/server.ts +++ b/src/server.ts @@ -37,7 +37,7 @@ import { handleEmbeddings } from "./embeddings.js"; import { handleImages } from "./images.js"; import { handleSpeech } from "./speech.js"; import { handleTranscription } from "./transcription.js"; -import { handleVideoCreate, handleVideoStatus, type VideoStateMap } from "./video.js"; +import { handleVideoCreate, handleVideoStatus, VideoStateMap } from "./video.js"; import { handleOllama, handleOllamaGenerate } from "./ollama.js"; import { handleCohere } from "./cohere.js"; import { handleSearch, type SearchFixture } from "./search.js"; @@ -48,7 +48,7 @@ import { handleWebSocketResponses } from "./ws-responses.js"; import { handleWebSocketRealtime } from "./ws-realtime.js"; import { handleWebSocketGeminiLive } from "./ws-gemini-live.js"; import { Logger } from "./logger.js"; -import { applyChaosAction, evaluateChaos } from "./chaos.js"; +import { applyChaos } from "./chaos.js"; import { createMetricsRegistry, normalizePathLabel } from "./metrics.js"; import { proxyAndRecord } from "./recorder.js"; @@ -102,7 +102,7 @@ const COMPAT_SUFFIXES = [ function normalizeCompatPath(pathname: string, logger?: Logger): string { // Strip /openai/ prefix (Groq/OpenAI-compat alias) if (pathname.startsWith("/openai/")) { - pathname = pathname.slice(7); + pathname = pathname.slice("/openai".length); } // Normalize arbitrary prefixes to /v1/ @@ -148,8 +148,7 @@ const DEFAULT_MODELS = [ const CORS_HEADERS: Record = { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS", - "Access-Control-Allow-Headers": - "Content-Type, Authorization, X-Aimock-Chaos-Drop, X-Aimock-Chaos-Malformed, X-Aimock-Chaos-Disconnect, X-Test-Id", + "Access-Control-Allow-Headers": "Content-Type, Authorization", }; function setCorsHeaders(res: http.ServerResponse): void { @@ -158,9 +157,23 @@ function setCorsHeaders(res: http.ServerResponse): void { } } -async function readBody(req: 
http.IncomingMessage): Promise { +const DEFAULT_MAX_BODY_BYTES = 10 * 1024 * 1024; // 10 MB + +async function readBody( + req: http.IncomingMessage, + maxBytes: number = DEFAULT_MAX_BODY_BYTES, +): Promise { const buffers: Buffer[] = []; - for await (const chunk of req) buffers.push(chunk as Buffer); + let totalBytes = 0; + for await (const chunk of req) { + const buf = chunk as Buffer; + totalBytes += buf.length; + if (totalBytes > maxBytes) { + req.destroy(); + throw new Error(`Request body exceeded size limit of ${maxBytes} bytes`); + } + buffers.push(buf); + } return Buffer.concat(buffers).toString(); } @@ -194,6 +207,7 @@ async function handleControlAPI( fixtures: Fixture[], journal: Journal, videoStates: VideoStateMap, + defaults: HandlerDefaults, ): Promise { if (!pathname.startsWith(CONTROL_PREFIX)) return false; @@ -219,18 +233,22 @@ async function handleControlAPI( let raw: string; try { raw = await readBody(req); - } catch { + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + defaults.logger.error(`POST /__aimock/fixtures: failed to read body: ${msg}`); res.writeHead(400, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ error: "Failed to read request body" })); + res.end(JSON.stringify({ error: `Failed to read request body: ${msg}` })); return true; } let parsed: { fixtures?: FixtureFileEntry[] }; try { parsed = JSON.parse(raw) as { fixtures?: FixtureFileEntry[] }; - } catch { + } catch (err) { + const msg = err instanceof Error ? 
err.message : String(err); + defaults.logger.error(`POST /__aimock/fixtures: invalid JSON: ${msg}`); res.writeHead(400, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ error: "Invalid JSON" })); + res.end(JSON.stringify({ error: `Invalid JSON: ${msg}` })); return true; } @@ -250,6 +268,9 @@ async function handleControlAPI( } fixtures.push(...converted); + if (defaults.registry) { + defaults.registry.setGauge("aimock_fixtures_loaded", {}, fixtures.length); + } res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify({ added: converted.length })); return true; @@ -258,6 +279,9 @@ async function handleControlAPI( // DELETE /__aimock/fixtures — clear all fixtures if (subPath === "/fixtures" && req.method === "DELETE") { fixtures.length = 0; + if (defaults.registry) { + defaults.registry.setGauge("aimock_fixtures_loaded", {}, fixtures.length); + } res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify({ cleared: true })); return true; @@ -268,6 +292,9 @@ async function handleControlAPI( fixtures.length = 0; journal.clear(); videoStates.clear(); + if (defaults.registry) { + defaults.registry.setGauge("aimock_fixtures_loaded", {}, fixtures.length); + } res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify({ reset: true })); return true; @@ -278,18 +305,22 @@ async function handleControlAPI( let raw: string; try { raw = await readBody(req); - } catch { + } catch (err) { + const msg = err instanceof Error ? 
err.message : String(err); + defaults.logger.error(`POST /__aimock/error: failed to read body: ${msg}`); res.writeHead(400, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ error: "Failed to read request body" })); + res.end(JSON.stringify({ error: `Failed to read request body: ${msg}` })); return true; } let parsed: { status?: number; body?: { message?: string; type?: string; code?: string } }; try { parsed = JSON.parse(raw) as typeof parsed; - } catch { + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + defaults.logger.error(`POST /__aimock/error: invalid JSON: ${msg}`); res.writeHead(400, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ error: "Invalid JSON" })); + res.end(JSON.stringify({ error: `Invalid JSON: ${msg}` })); return true; } @@ -308,15 +339,14 @@ async function handleControlAPI( }; // Insert at front so it matches before everything else fixtures.unshift(errorFixture); - // Remove after first match + // Remove synchronously on first match to prevent race conditions where + // two concurrent requests both match before the removal fires. const original = errorFixture.match.predicate!; errorFixture.match.predicate = (req) => { const result = original(req); if (result) { - queueMicrotask(() => { - const idx = fixtures.indexOf(errorFixture); - if (idx !== -1) fixtures.splice(idx, 1); - }); + const idx = fixtures.indexOf(errorFixture); + if (idx !== -1) fixtures.splice(idx, 1); } return result; }; @@ -398,16 +428,30 @@ async function handleCompletions( return; } - const method = req.method ?? "POST"; - const path = req.url ?? COMPLETIONS_PATH; - const flatHeaders = flattenHeaders(req.headers); + // Validate messages array + if (!Array.isArray(body.messages)) { + journal.add({ + method: req.method ?? "POST", + path: req.url ?? 
COMPLETIONS_PATH, + headers: flattenHeaders(req.headers), + body: null, + response: { status: 400, fixture: null }, + }); + writeErrorResponse( + res, + 400, + JSON.stringify({ + error: { + message: "Missing required parameter: 'messages'", + type: "invalid_request_error", + }, + }), + ); + return; + } - // Set endpoint type once early so router/recorder and journal see it + // Match fixture body._endpointType = "chat"; - - // Match fixture first — chaos resolution depends on fixture-level overrides - // (headers > fixture.chaos > server defaults), so the fixture has to be - // known before we can roll with the right config. const testId = getTestId(req); const fixture = matchFixture( fixtures, @@ -420,88 +464,34 @@ async function handleCompletions( journal.incrementFixtureMatchCount(fixture, fixtures, testId); } - // Roll chaos once per request. Dispatch by action + path: - // drop / disconnect → apply immediately; upstream is never called and no - // response body is produced. - // malformed, fixture path → write invalid JSON instead of the fixture. - // malformed, proxy path → proxy to upstream, then swap body via the - // beforeWriteResponse hook (passed only when the - // action is malformed, so the hook doesn't need - // to re-check the action). - const chaosAction = evaluateChaos(fixture, defaults.chaos, req.headers, defaults.logger); - const chaosContext = { method, path, headers: flatHeaders, body }; - - if (chaosAction === "drop" || chaosAction === "disconnect") { - applyChaosAction( - chaosAction, - res, - fixture, - journal, - chaosContext, - fixture ? "fixture" : "proxy", - defaults.registry, - ); - return; - } + const method = req.method ?? "POST"; + const path = req.url ?? 
COMPLETIONS_PATH; + const flatHeaders = flattenHeaders(req.headers); - if (fixture && chaosAction === "malformed") { - applyChaosAction( - chaosAction, + // Apply chaos before normal response handling + if ( + applyChaos( res, fixture, + defaults.chaos, + req.headers, journal, - chaosContext, - "fixture", + { + method, + path, + headers: flatHeaders, + body, + }, defaults.registry, - ); + defaults.logger, + ) + ) return; - } if (!fixture) { // Try record-and-replay proxy if configured if (defaults.record && providerKey) { - // Hook is only passed when chaos wants to mutate the response. When - // it's passed, it unconditionally applies malformed + journals + tells - // proxyAndRecord to skip its default relay. The hook has no branching - // logic — that decision is made here, at the call site. - const hookOptions = - chaosAction === "malformed" - ? { - // Malformed is emitted as a hardcoded invalid-JSON body, so the - // captured upstream response isn't used here (the parameter is - // intentionally omitted rather than declared-and-ignored). - // Future dispatch (phase 3: non-JSON / streaming) will accept - // the response and branch on contentType. - beforeWriteResponse: () => { - applyChaosAction( - chaosAction, - res, - null, - journal, - chaosContext, - "proxy", - defaults.registry, - ); - return true; - }, - // SSE can't be mutated post-facto (bytes already on the wire). - // Record the bypass so the rolled action isn't invisible in - // logs / Prometheus — otherwise malformedRate: 1.0 on SSE - // traffic silently means 0%. 
- onHookBypassed: (reason: "sse_streamed") => { - defaults.logger.warn( - `[chaos] malformed bypassed on proxy: upstream returned SSE (${reason})`, - ); - defaults.registry?.incrementCounter("aimock_chaos_bypassed_total", { - action: "malformed", - source: "proxy", - reason, - }); - }, - } - : undefined; - - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, body, @@ -510,10 +500,8 @@ async function handleCompletions( fixtures, defaults, raw, - hookOptions, ); - if (outcome === "handled_by_hook") return; - if (outcome === "relayed") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: req.url ?? COMPLETIONS_PATH, @@ -523,7 +511,6 @@ async function handleCompletions( }); return; } - // outcome === "not_configured" — fall through to strict/404 } const strictStatus = defaults.strict ? 503 : 404; @@ -788,7 +775,7 @@ export async function createServer( maxEntries: options?.journalMaxEntries ?? 1000, fixtureCountsMaxTestIds: options?.fixtureCountsMaxTestIds ?? 500, }); - const videoStates: VideoStateMap = new Map(); + const videoStates = new VideoStateMap(); // Share journal and metrics registry with mounted services if (mounts) { @@ -860,7 +847,7 @@ export async function createServer( // Control API — must be checked before mounts and path rewrites if (pathname.startsWith(CONTROL_PREFIX)) { - await handleControlAPI(req, res, pathname, fixtures, journal, videoStates); + await handleControlAPI(req, res, pathname, fixtures, journal, videoStates, defaults); return; } @@ -992,204 +979,208 @@ export async function createServer( // POST /v1/responses — OpenAI Responses API if (pathname === RESPONSES_PATH && req.method === "POST") { - readBody(req) - .then((raw) => handleResponses(req, res, raw, fixtures, journal, defaults, setCorsHeaders)) - .catch((err: unknown) => { - const msg = err instanceof Error ? 
err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - try { - res.write(`event: error\ndata: ${JSON.stringify({ error: { message: msg } })}\n\n`); - } catch (writeErr) { - logger.debug("Failed to write error recovery response:", writeErr); - } - res.end(); + try { + const raw = await readBody(req); + await handleResponses(req, res, raw, fixtures, journal, defaults, setCorsHeaders); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + try { + res.write(`event: error\ndata: ${JSON.stringify({ error: { message: msg } })}\n\n`); + } catch (writeErr) { + logger.debug("Failed to write error recovery response:", writeErr); } - }); + res.end(); + } + } return; } // POST /v1/messages — Anthropic Claude Messages API if (pathname === MESSAGES_PATH && req.method === "POST") { - readBody(req) - .then((raw) => handleMessages(req, res, raw, fixtures, journal, defaults, setCorsHeaders)) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - try { - res.write(`event: error\ndata: ${JSON.stringify({ error: { message: msg } })}\n\n`); - } catch (writeErr) { - logger.debug("Failed to write error recovery response:", writeErr); - } - res.end(); + try { + const raw = await readBody(req); + await handleMessages(req, res, raw, fixtures, journal, defaults, setCorsHeaders); + } catch (err: unknown) { + const msg = err instanceof Error ? 
err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + try { + res.write(`event: error\ndata: ${JSON.stringify({ error: { message: msg } })}\n\n`); + } catch (writeErr) { + logger.debug("Failed to write error recovery response:", writeErr); } - }); + res.end(); + } + } return; } // POST /v2/chat — Cohere v2 Chat API if (pathname === COHERE_CHAT_PATH && req.method === "POST") { - readBody(req) - .then((raw) => handleCohere(req, res, raw, fixtures, journal, defaults, setCorsHeaders)) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - try { - res.write(`event: error\ndata: ${JSON.stringify({ error: { message: msg } })}\n\n`); - } catch (writeErr) { - logger.debug("Failed to write error recovery response:", writeErr); - } - res.end(); + try { + const raw = await readBody(req); + await handleCohere(req, res, raw, fixtures, journal, defaults, setCorsHeaders); + } catch (err: unknown) { + const msg = err instanceof Error ? 
err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + try { + res.write(`event: error\ndata: ${JSON.stringify({ error: { message: msg } })}\n\n`); + } catch (writeErr) { + logger.debug("Failed to write error recovery response:", writeErr); } - }); + res.end(); + } + } return; } // POST /v1/embeddings — OpenAI Embeddings API if (pathname === EMBEDDINGS_PATH && req.method === "POST") { - const deploymentId = azureDeploymentId; - readBody(req) - .then((raw) => { - // Azure deployments may omit model from body — use deployment ID as fallback - if (deploymentId) { - try { - const parsed = JSON.parse(raw) as Record; - if (!parsed.model) { - parsed.model = deploymentId; - return handleEmbeddings( - req, - res, - JSON.stringify(parsed), - fixtures, - journal, - defaults, - setCorsHeaders, - ); - } - } catch { - // Fall through — let handleEmbeddings report the parse error + try { + const deploymentId = azureDeploymentId; + const embeddingsProvider: RecordProviderKey = azureDeploymentId ? "azure" : "openai"; + let raw = await readBody(req); + // Azure deployments may omit model from body — use deployment ID as fallback + if (deploymentId) { + try { + const parsed = JSON.parse(raw) as Record; + if (!parsed.model) { + parsed.model = deploymentId; + raw = JSON.stringify(parsed); } + } catch { + // Fall through — let handleEmbeddings report the parse error } - return handleEmbeddings(req, res, raw, fixtures, journal, defaults, setCorsHeaders); - }) - .catch((err: unknown) => { - const msg = err instanceof Error ? 
err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + } + await handleEmbeddings( + req, + res, + raw, + fixtures, + journal, + defaults, + setCorsHeaders, + embeddingsProvider, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } // POST /v1/images/generations — OpenAI Image Generation API if (pathname === IMAGES_PATH && req.method === "POST") { - readBody(req) - .then((raw) => handleImages(req, res, raw, fixtures, journal, defaults, setCorsHeaders)) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + try { + const raw = await readBody(req); + await handleImages(req, res, raw, fixtures, journal, defaults, setCorsHeaders); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } // POST /v1/audio/speech — OpenAI TTS API if (pathname === SPEECH_PATH && req.method === "POST") { - readBody(req) - .then((raw) => handleSpeech(req, res, raw, fixtures, journal, defaults, setCorsHeaders)) - .catch((err: unknown) => { - const msg = err instanceof Error ? 
err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + try { + const raw = await readBody(req); + await handleSpeech(req, res, raw, fixtures, journal, defaults, setCorsHeaders); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } // POST /v1/audio/transcriptions — OpenAI Transcription API if (pathname === TRANSCRIPTIONS_PATH && req.method === "POST") { - readBody(req) - .then((raw) => - handleTranscription(req, res, raw, fixtures, journal, defaults, setCorsHeaders), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + try { + const raw = await readBody(req); + await handleTranscription(req, res, raw, fixtures, journal, defaults, setCorsHeaders); + } catch (err: unknown) { + const msg = err instanceof Error ? 
err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } // POST /v1/videos — Video Generation API if (pathname === VIDEOS_PATH && req.method === "POST") { - readBody(req) - .then((raw) => - handleVideoCreate( - req, + try { + const raw = await readBody(req); + await handleVideoCreate( + req, + res, + raw, + fixtures, + journal, + defaults, + setCorsHeaders, + videoStates, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - fixtures, - journal, - defaults, - setCorsHeaders, - videoStates, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } @@ -1197,7 +1188,7 @@ export async function createServer( const videoStatusMatch = pathname.match(VIDEOS_STATUS_RE); if (videoStatusMatch && req.method === "GET") { const videoId = videoStatusMatch[1]; - handleVideoStatus(req, res, videoId, journal, defaults, setCorsHeaders, videoStates); + handleVideoStatus(req, res, videoId, journal, setCorsHeaders, videoStates); return; } @@ -1205,32 +1196,31 @@ export async function createServer( const geminiPredictMatch = pathname.match(GEMINI_PREDICT_RE); if (geminiPredictMatch && req.method === "POST") { const predictModel = geminiPredictMatch[1]; - readBody(req) - .then((raw) => - handleImages( - req, + try { + const raw = await readBody(req); + await handleImages( + req, + res, + raw, + fixtures, + 
journal, + defaults, + setCorsHeaders, + "gemini", + predictModel, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - fixtures, - journal, - defaults, - setCorsHeaders, - "gemini", - predictModel, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } @@ -1239,37 +1229,36 @@ export async function createServer( if (geminiMatch && req.method === "POST") { const geminiModel = geminiMatch[1]; const streaming = geminiMatch[2] === "streamGenerateContent"; - readBody(req) - .then((raw) => - handleGemini( - req, + try { + const raw = await readBody(req); + await handleGemini( + req, + res, + raw, + geminiModel, + streaming, + fixtures, + journal, + defaults, + setCorsHeaders, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - geminiModel, - streaming, - fixtures, - journal, - defaults, - setCorsHeaders, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? 
err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - try { - res.write(`data: ${JSON.stringify({ error: { message: msg } })}\n\n`); - } catch (writeErr) { - logger.debug("Failed to write error recovery response:", writeErr); - } - res.end(); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + try { + res.write(`data: ${JSON.stringify({ error: { message: msg } })}\n\n`); + } catch (writeErr) { + logger.debug("Failed to write error recovery response:", writeErr); } - }); + res.end(); + } + } return; } @@ -1278,38 +1267,37 @@ export async function createServer( if (vertexMatch && req.method === "POST") { const vertexModel = vertexMatch[1]; const streaming = vertexMatch[2] === "streamGenerateContent"; - readBody(req) - .then((raw) => - handleGemini( - req, + try { + const raw = await readBody(req); + await handleGemini( + req, + res, + raw, + vertexModel, + streaming, + fixtures, + journal, + defaults, + setCorsHeaders, + "vertexai", + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - vertexModel, - streaming, - fixtures, - journal, - defaults, - setCorsHeaders, - "vertexai", - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? 
err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - try { - res.write(`data: ${JSON.stringify({ error: { message: msg } })}\n\n`); - } catch (writeErr) { - logger.debug("Failed to write error recovery response:", writeErr); - } - res.end(); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + try { + res.write(`data: ${JSON.stringify({ error: { message: msg } })}\n\n`); + } catch (writeErr) { + logger.debug("Failed to write error recovery response:", writeErr); } - }); + res.end(); + } + } return; } @@ -1317,22 +1305,30 @@ export async function createServer( const bedrockMatch = pathname.match(BEDROCK_INVOKE_RE); if (bedrockMatch && req.method === "POST") { const bedrockModelId = bedrockMatch[1]; - readBody(req) - .then((raw) => - handleBedrock(req, res, raw, bedrockModelId, fixtures, journal, defaults, setCorsHeaders), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + try { + const raw = await readBody(req); + await handleBedrock( + req, + res, + raw, + bedrockModelId, + fixtures, + journal, + defaults, + setCorsHeaders, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? 
err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } @@ -1340,31 +1336,30 @@ export async function createServer( const bedrockStreamMatch = pathname.match(BEDROCK_STREAM_RE); if (bedrockStreamMatch && req.method === "POST") { const bedrockModelId = bedrockStreamMatch[1]; - readBody(req) - .then((raw) => - handleBedrockStream( - req, + try { + const raw = await readBody(req); + await handleBedrockStream( + req, + res, + raw, + bedrockModelId, + fixtures, + journal, + defaults, + setCorsHeaders, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - bedrockModelId, - fixtures, - journal, - defaults, - setCorsHeaders, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } @@ -1372,31 +1367,30 @@ export async function createServer( const converseMatch = pathname.match(BEDROCK_CONVERSE_RE); if (converseMatch && req.method === "POST") { const converseModelId = converseMatch[1]; - readBody(req) - .then((raw) => - handleConverse( - req, + try { + const raw = await readBody(req); + await handleConverse( + req, + res, + raw, + converseModelId, + fixtures, + journal, + defaults, + setCorsHeaders, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? 
err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - converseModelId, - fixtures, - journal, - defaults, - setCorsHeaders, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } @@ -1404,71 +1398,70 @@ export async function createServer( const converseStreamMatch = pathname.match(BEDROCK_CONVERSE_STREAM_RE); if (converseStreamMatch && req.method === "POST") { const converseStreamModelId = converseStreamMatch[1]; - readBody(req) - .then((raw) => - handleConverseStream( - req, + try { + const raw = await readBody(req); + await handleConverseStream( + req, + res, + raw, + converseStreamModelId, + fixtures, + journal, + defaults, + setCorsHeaders, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - converseStreamModelId, - fixtures, - journal, - defaults, - setCorsHeaders, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? 
err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } // POST /api/chat — Ollama Chat API if (pathname === OLLAMA_CHAT_PATH && req.method === "POST") { - readBody(req) - .then((raw) => handleOllama(req, res, raw, fixtures, journal, defaults, setCorsHeaders)) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + try { + const raw = await readBody(req); + await handleOllama(req, res, raw, fixtures, journal, defaults, setCorsHeaders); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } // POST /api/generate — Ollama Generate API if (pathname === OLLAMA_GENERATE_PATH && req.method === "POST") { - readBody(req) - .then((raw) => - handleOllamaGenerate(req, res, raw, fixtures, journal, defaults, setCorsHeaders), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? 
err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + try { + const raw = await readBody(req); + await handleOllamaGenerate(req, res, raw, fixtures, journal, defaults, setCorsHeaders); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( + res, + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } @@ -1497,88 +1490,85 @@ export async function createServer( // POST /search — Web Search API (Tavily-compatible) if (pathname === SEARCH_PATH && req.method === "POST") { - readBody(req) - .then((raw) => - handleSearch( - req, + try { + const raw = await readBody(req); + await handleSearch( + req, + res, + raw, + serviceFixtures?.search ?? [], + journal, + defaults, + setCorsHeaders, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - serviceFixtures?.search ?? [], - journal, - defaults, - setCorsHeaders, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? 
err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } // POST /v2/rerank — Reranking API (Cohere rerank-compatible) if (pathname === RERANK_PATH && req.method === "POST") { - readBody(req) - .then((raw) => - handleRerank( - req, + try { + const raw = await readBody(req); + await handleRerank( + req, + res, + raw, + serviceFixtures?.rerank ?? [], + journal, + defaults, + setCorsHeaders, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - serviceFixtures?.rerank ?? [], - journal, - defaults, - setCorsHeaders, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } // POST /v1/moderations — Moderation API (OpenAI-compatible) if (pathname === MODERATIONS_PATH && req.method === "POST") { - readBody(req) - .then((raw) => - handleModeration( - req, + try { + const raw = await readBody(req); + await handleModeration( + req, + res, + raw, + serviceFixtures?.moderation ?? [], + journal, + defaults, + setCorsHeaders, + ); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : "Internal error"; + if (!res.headersSent) { + writeErrorResponse( res, - raw, - serviceFixtures?.moderation ?? 
[], - journal, - defaults, - setCorsHeaders, - ), - ) - .catch((err: unknown) => { - const msg = err instanceof Error ? err.message : "Internal error"; - if (!res.headersSent) { - writeErrorResponse( - res, - 500, - JSON.stringify({ error: { message: msg, type: "server_error" } }), - ); - } else if (!res.writableEnded) { - res.destroy(); - } - }); + 500, + JSON.stringify({ error: { message: msg, type: "server_error" } }), + ); + } else if (!res.writableEnded) { + res.destroy(); + } + } return; } @@ -1593,15 +1583,17 @@ export async function createServer( } const completionsProvider: RecordProviderKey = azureDeploymentId ? "azure" : "openai"; - handleCompletions( - req, - res, - fixtures, - journal, - defaults, - azureDeploymentId, - completionsProvider, - ).catch((err: unknown) => { + try { + await handleCompletions( + req, + res, + fixtures, + journal, + defaults, + azureDeploymentId, + completionsProvider, + ); + } catch (err: unknown) { const msg = err instanceof Error ? err.message : "Internal error"; if (!res.headersSent) { writeErrorResponse( @@ -1625,7 +1617,7 @@ export async function createServer( } res.end(); } - }); + } } // ─── WebSocket upgrade handling ────────────────────────────────────────── diff --git a/src/types.ts b/src/types.ts index 3245c0a..e434f37 100644 --- a/src/types.ts +++ b/src/types.ts @@ -196,15 +196,6 @@ export interface StreamingProfile { jitter?: number; // Random variance factor (0-1), default 0 } -/** - * Probabilistic chaos injection rates. - * - * Rates are evaluated sequentially per request — drop → malformed → disconnect - * — and the first hit wins. Consequently malformedRate is conditional on drop - * not firing, and disconnectRate is conditional on neither drop nor malformed - * firing. A config of `{ dropRate: 0.5, malformedRate: 0.5 }` yields a ~25 % - * effective malformed rate, not 50 %. 
- */ export interface ChaosConfig { dropRate?: number; malformedRate?: number; @@ -311,16 +302,10 @@ export interface JournalEntry { response: { status: number; fixture: Fixture | null; + source?: "fixture" | "proxy"; interrupted?: boolean; interruptReason?: string; chaosAction?: ChaosAction; - /** - * What was going to serve this request. "fixture" = a fixture matched (or - * would have, before chaos intervened). "proxy" = no fixture matched and - * proxy was configured. Absent when the distinction doesn't apply (e.g. - * 404/503 fallback where nothing was going to serve). - */ - source?: "fixture" | "proxy"; }; } From 31b2436857df03fa0f88b5322ad3e0e163873261 Mon Sep 17 00:00:00 2001 From: Jordan Ritter Date: Thu, 23 Apr 2026 16:23:06 -0700 Subject: [PATCH 2/6] fix: recorder crash hardening and recording fidelity headersSent guards, clientDisconnected tracking, content+toolCalls preservation, Cohere v2 detection, tool-call ID extraction (5 providers), reasoning extraction (4 providers), filter+join multi-block text, thinking-only/empty-content handling, Ollama /api/generate detection, streaming collapse reasoning propagation. 
--- src/__tests__/multimedia-record.test.ts | 4 +- src/__tests__/proxy-only.test.ts | 221 ------------ src/__tests__/recorder.test.ts | 221 ++++++++---- src/recorder.ts | 439 ++++++++++++++++-------- 4 files changed, 448 insertions(+), 437 deletions(-) diff --git a/src/__tests__/multimedia-record.test.ts b/src/__tests__/multimedia-record.test.ts index 4d915ba..9f28970 100644 --- a/src/__tests__/multimedia-record.test.ts +++ b/src/__tests__/multimedia-record.test.ts @@ -105,7 +105,7 @@ describe("multimedia record: image response detection", () => { }; const { req, res } = createMockReqRes("/v1/images/generations"); - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, request, @@ -115,7 +115,7 @@ describe("multimedia record: image response detection", () => { { record, logger }, ); - expect(outcome).toBe("relayed"); + expect(proxied).toBe(true); expect(fixtures).toHaveLength(1); const fixture = fixtures[0]; expect(fixture.match.endpoint).toBe("image"); diff --git a/src/__tests__/proxy-only.test.ts b/src/__tests__/proxy-only.test.ts index 7645638..37eef41 100644 --- a/src/__tests__/proxy-only.test.ts +++ b/src/__tests__/proxy-only.test.ts @@ -225,227 +225,6 @@ describe("proxy-only mode", () => { await new Promise((resolve) => countingUpstream.server.close(() => resolve())); }); - it("applies chaos BEFORE proxying (drop)", async () => { - const countingUpstream = await createCountingUpstream("should not be hit"); - - recorder = await createServer([], { - port: 0, - chaos: { dropRate: 1.0 }, - record: { - providers: { openai: countingUpstream.url }, - fixturePath: (tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "aimock-chaos-proxy-"))), - proxyOnly: true, - }, - }); - - const resp = await post(`${recorder.url}/v1/chat/completions`, CHAT_REQUEST); - - expect(resp.status).toBe(500); - expect(countingUpstream.getCount()).toBe(0); - - await new Promise((resolve) => countingUpstream.server.close(() => resolve())); - }); - - it("applies 
chaos BEFORE proxying (disconnect)", async () => { - const countingUpstream = await createCountingUpstream("should not be hit"); - - recorder = await createServer([], { - port: 0, - chaos: { disconnectRate: 1.0 }, - record: { - providers: { openai: countingUpstream.url }, - fixturePath: (tmpDir = fs.mkdtempSync( - path.join(os.tmpdir(), "aimock-chaos-proxy-disconnect-"), - )), - proxyOnly: true, - }, - }); - - await expect(post(`${recorder.url}/v1/chat/completions`, CHAT_REQUEST)).rejects.toThrow(); - - expect(countingUpstream.getCount()).toBe(0); - - await new Promise((resolve) => countingUpstream.server.close(() => resolve())); - }); - - it("applies malformed chaos AFTER proxying (upstream called, body corrupted, journaled)", async () => { - const countingUpstream = await createCountingUpstream("valid content"); - - recorder = await createServer([], { - port: 0, - chaos: { malformedRate: 1.0 }, - record: { - providers: { openai: countingUpstream.url }, - fixturePath: (tmpDir = fs.mkdtempSync( - path.join(os.tmpdir(), "aimock-chaos-postresponse-"), - )), - proxyOnly: true, - }, - }); - - const resp = await post(`${recorder.url}/v1/chat/completions`, CHAT_REQUEST); - - // Upstream IS called: malformed is a post-response mutation, not a pre-flight drop - expect(countingUpstream.getCount()).toBe(1); - // Client sees 200 with a body that does NOT parse as JSON - expect(resp.status).toBe(200); - expect(() => JSON.parse(resp.body)).toThrow(); - // Journal records the chaos action exactly once (no double-entry from the - // chaos path + the default proxy-relay path) - expect(recorder.journal.size).toBe(1); - const last = recorder.journal.getLast(); - expect(last?.response.chaosAction).toBe("malformed"); - expect(last?.response.fixture).toBeNull(); - - await new Promise((resolve) => countingUpstream.server.close(() => resolve())); - }); - - it("preserves upstream content-type on replay when no chaos fires", async () => { - const countingUpstream = await 
createCountingUpstream("valid content"); - - recorder = await createServer([], { - port: 0, - record: { - providers: { openai: countingUpstream.url }, - fixturePath: (tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "aimock-proxy-ct-"))), - proxyOnly: true, - }, - }); - - const resp = await post(`${recorder.url}/v1/chat/completions`, CHAT_REQUEST); - - expect(resp.status).toBe(200); - const ct = resp.headers["content-type"]; - expect(typeof ct === "string" ? ct : "").toContain("application/json"); - // Body is valid JSON and round-trips - expect(JSON.parse(resp.body).choices[0].message.content).toBe("valid content"); - - await new Promise((resolve) => countingUpstream.server.close(() => resolve())); - }); - - it("proxy failure produces 502 end-to-end and journals the failure", async () => { - // Integration test: unit tests prove recorder.ts writes 502 on upstream - // failure; this pins that handleCompletions handles the "relayed" outcome - // correctly (journals, doesn't hang). - recorder = await createServer([], { - port: 0, - record: { - providers: { openai: "http://127.0.0.1:1" }, // port 1 — unreachable - fixturePath: (tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "aimock-proxy-fail-"))), - proxyOnly: true, - }, - }); - - const resp = await post(`${recorder.url}/v1/chat/completions`, CHAT_REQUEST); - - expect(resp.status).toBe(502); - expect(recorder.journal.size).toBe(1); - const entry = recorder.journal.getLast(); - expect(entry?.response.status).toBe(502); - expect(entry?.response.fixture).toBeNull(); - expect(entry?.response.source).toBe("proxy"); - expect(entry?.response.chaosAction).toBeUndefined(); - }); - - it("chaos + proxy failure: malformed was rolled but upstream failed → 502, no chaosAction", async () => { - // Integration test: when chaos rolls malformed but the upstream request - // fails, proxyAndRecord synthesizes a 502 before the hook is invoked. 
The - // journal should reflect what actually happened (502, no chaos) rather - // than what was intended. - recorder = await createServer([], { - port: 0, - chaos: { malformedRate: 1.0 }, - record: { - providers: { openai: "http://127.0.0.1:1" }, // unreachable - fixturePath: (tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "aimock-chaos-proxy-fail-"))), - proxyOnly: true, - }, - }); - - const resp = await post(`${recorder.url}/v1/chat/completions`, CHAT_REQUEST); - - // Client sees the proxy failure, NOT a malformed-JSON body - expect(resp.status).toBe(502); - expect(() => JSON.parse(resp.body)).not.toThrow(); - - expect(recorder.journal.size).toBe(1); - const entry = recorder.journal.getLast(); - expect(entry?.response.status).toBe(502); - expect(entry?.response.source).toBe("proxy"); - // Chaos was rolled but never applied — journal must not claim it fired - expect(entry?.response.chaosAction).toBeUndefined(); - }); - - it("SSE upstream bypasses malformed chaos: body intact, bypass counted, journal clean", async () => { - // Pins the one place chaos silently no-ops: when upstream streams SSE, - // the bytes are already on the wire before the chaos hook could fire. - // Without an explicit bypass signal, malformedRate: 1.0 on SSE traffic - // would silently mean 0% corruption with no log, metric, or journal - // trace. Lifting the gate out of recorder.ts in a future refactor - // (phase 3: streaming mutation) should trip this test. 
- const sseUpstream = await new Promise<{ server: http.Server; url: string }>((resolve) => { - const server = http.createServer((_req, res) => { - res.writeHead(200, { "Content-Type": "text/event-stream" }); - res.write(`data: ${JSON.stringify({ choices: [{ delta: { content: "hi" } }] })}\n\n`); - res.write("data: [DONE]\n\n"); - res.end(); - }); - server.listen(0, "127.0.0.1", () => { - const { port } = server.address() as { port: number }; - resolve({ server, url: `http://127.0.0.1:${port}` }); - }); - }); - - recorder = await createServer([], { - port: 0, - metrics: true, - chaos: { malformedRate: 1.0 }, - record: { - providers: { openai: sseUpstream.url }, - fixturePath: (tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "aimock-chaos-sse-"))), - proxyOnly: true, - }, - }); - - const resp = await post(`${recorder.url}/v1/chat/completions`, CHAT_REQUEST); - - // Client receives a real SSE stream — content-type and frames intact, not - // the malformed-JSON sentinel. - expect(resp.status).toBe(200); - const ct = resp.headers["content-type"]; - expect(typeof ct === "string" ? ct : "").toContain("text/event-stream"); - expect(resp.body).toContain("data: "); - expect(resp.body).not.toContain("{malformed json"); - - // Journal records the relayed proxy call, NOT a chaos action — the - // chaos roll happened but couldn't be applied, so claiming it fired - // would be a lie to the observer. - expect(recorder.journal.size).toBe(1); - const last = recorder.journal.getLast(); - expect(last?.response.chaosAction).toBeUndefined(); - expect(last?.response.source).toBe("proxy"); - - // Bypass must be visible in metrics so operators can see that a - // configured chaos action didn't fire. 
- const metricsRes = await new Promise<{ body: string }>((resolve, reject) => { - const mReq = http.request(`${recorder!.url}/metrics`, { method: "GET" }, (mRes) => { - const chunks: Buffer[] = []; - mRes.on("data", (c: Buffer) => chunks.push(c)); - mRes.on("end", () => resolve({ body: Buffer.concat(chunks).toString() })); - }); - mReq.on("error", reject); - mReq.end(); - }); - expect(metricsRes.body).toMatch( - /aimock_chaos_bypassed_total\{[^}]*action="malformed"[^}]*source="proxy"[^}]*\} 1/, - ); - // Paired negative: the normal chaos_triggered counter must NOT increment - // for a bypass — the action didn't actually fire. - expect(metricsRes.body).not.toMatch(/aimock_chaos_triggered_total\{[^}]*action="malformed"/); - - await new Promise((resolve) => sseUpstream.server.close(() => resolve())); - }); - it("regular record mode DOES cache in memory — second request served from cache", async () => { // Use a counting upstream to verify only the first request is proxied const countingUpstream = await createCountingUpstream("cached response"); diff --git a/src/__tests__/recorder.test.ts b/src/__tests__/recorder.test.ts index 0eb2a47..eaadc7b 100644 --- a/src/__tests__/recorder.test.ts +++ b/src/__tests__/recorder.test.ts @@ -5,7 +5,7 @@ import * as os from "node:os"; import * as path from "node:path"; import type { Fixture, FixtureFile } from "../types.js"; import { createServer, type ServerInstance } from "../server.js"; -import { proxyAndRecord, type ProxyCapturedResponse } from "../recorder.js"; +import { proxyAndRecord } from "../recorder.js"; import type { RecordConfig } from "../types.js"; import { Logger } from "../logger.js"; import { LLMock } from "../llmock.js"; @@ -110,13 +110,13 @@ afterEach(async () => { // --------------------------------------------------------------------------- describe("proxyAndRecord", () => { - it('returns "not_configured" when provider is not configured', async () => { + it("returns false when provider is not configured", async 
() => { const fixtures: Fixture[] = []; const logger = new Logger("silent"); const record: RecordConfig = { providers: {} }; // Create a mock req/res pair — we just need them to exist, - // proxyAndRecord should short-circuit before using them + // proxyAndRecord should return false before using them const { req, res } = createMockReqRes(); const result = await proxyAndRecord( @@ -129,10 +129,10 @@ describe("proxyAndRecord", () => { { record, logger }, ); - expect(result).toBe("not_configured"); + expect(result).toBe(false); }); - it('returns "not_configured" when record config is undefined', async () => { + it("returns false when record config is undefined", async () => { const fixtures: Fixture[] = []; const logger = new Logger("silent"); @@ -148,71 +148,7 @@ describe("proxyAndRecord", () => { { record: undefined, logger }, ); - expect(result).toBe("not_configured"); - }); - - it("beforeWriteResponse hook receives raw upstream bytes (binary-safe)", async () => { - // Pins the refactor's claim that the hook sees raw upstream bytes, not a - // UTF-8-decoded-then-re-encoded view. Uses a deliberately non-UTF8 byte - // sequence so any round-trip through String() would corrupt it. - const bytes = Buffer.from([0xff, 0xfe, 0xfd, 0x00, 0x01, 0x02, 0x7f, 0x80]); - - const binaryUpstream = http.createServer((_upReq, upRes) => { - upRes.writeHead(200, { "Content-Type": "application/octet-stream" }); - upRes.end(bytes); - }); - await new Promise((resolve) => binaryUpstream.listen(0, "127.0.0.1", () => resolve())); - const upstreamPort = (binaryUpstream.address() as { port: number }).port; - - let captured: ProxyCapturedResponse | undefined; - - // Minimal HTTP server that invokes proxyAndRecord with our capture hook, - // so req/res are real and the full recorder pipeline exercises the hook. 
- const recorderServer = http.createServer((req, res) => { - const chunks: Buffer[] = []; - req.on("data", (c: Buffer) => chunks.push(c)); - req.on("end", async () => { - const rawBody = Buffer.concat(chunks).toString(); - await proxyAndRecord( - req, - res, - JSON.parse(rawBody), - "openai", - "/v1/chat/completions", - [], - { - record: { - providers: { openai: `http://127.0.0.1:${upstreamPort}` }, - proxyOnly: true, - }, - logger: new Logger("silent"), - }, - rawBody, - { - beforeWriteResponse: (response) => { - captured = response; - return false; // let the default relay proceed; we only wanted to observe - }, - }, - ); - }); - }); - await new Promise((resolve) => recorderServer.listen(0, "127.0.0.1", () => resolve())); - const recorderPort = (recorderServer.address() as { port: number }).port; - - try { - await post(`http://127.0.0.1:${recorderPort}/v1/chat/completions`, { - model: "gpt-4", - messages: [{ role: "user", content: "hi" }], - }); - - expect(captured).toBeDefined(); - expect(captured!.body).toBeInstanceOf(Buffer); - expect(Buffer.compare(captured!.body, bytes)).toBe(0); - } finally { - await new Promise((resolve) => binaryUpstream.close(() => resolve())); - await new Promise((resolve) => recorderServer.close(() => resolve())); - } + expect(result).toBe(false); }); }); @@ -2361,6 +2297,151 @@ describe("buildFixtureResponse format detection", () => { expect(fixtureContent.fixtures[0].response.content).toBeUndefined(); }); + it("detects Cohere v2 message-level tool_calls with text content", async () => { + const { url: upstreamUrl } = await createRawUpstreamWithStatus({ + finish_reason: "TOOL_CALL", + message: { + role: "assistant", + content: [{ type: "text", text: "Let me look that up." 
}], + tool_calls: [ + { + name: "get_weather", + parameters: { city: "SF" }, + }, + ], + }, + }); + + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "aimock-record-")); + recorder = await createServer([], { + port: 0, + record: { providers: { cohere: upstreamUrl }, fixturePath: tmpDir }, + }); + + const resp = await post(`${recorder.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "cohere v2 msg tool_calls with text" }], + stream: false, + }); + + expect(resp.status).toBe(200); + + const files = fs.readdirSync(tmpDir); + const fixtureFiles = files.filter((f) => f.endsWith(".json")); + expect(fixtureFiles).toHaveLength(1); + + const fixtureContent = JSON.parse( + fs.readFileSync(path.join(tmpDir, fixtureFiles[0]), "utf-8"), + ) as { + fixtures: Array<{ + response: { + content?: string; + toolCalls?: Array<{ name: string; arguments: string }>; + }; + }>; + }; + expect(fixtureContent.fixtures[0].response.content).toBe("Let me look that up."); + expect(fixtureContent.fixtures[0].response.toolCalls).toBeDefined(); + expect(fixtureContent.fixtures[0].response.toolCalls).toHaveLength(1); + expect(fixtureContent.fixtures[0].response.toolCalls![0].name).toBe("get_weather"); + expect(JSON.parse(fixtureContent.fixtures[0].response.toolCalls![0].arguments)).toEqual({ + city: "SF", + }); + }); + + it("detects Cohere v2 message-level tool_calls without text content", async () => { + const { url: upstreamUrl } = await createRawUpstreamWithStatus({ + finish_reason: "TOOL_CALL", + message: { + role: "assistant", + content: [], + tool_calls: [ + { + name: "search_docs", + parameters: { query: "aimock" }, + }, + ], + }, + }); + + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "aimock-record-")); + recorder = await createServer([], { + port: 0, + record: { providers: { cohere: upstreamUrl }, fixturePath: tmpDir }, + }); + + const resp = await post(`${recorder.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "cohere v2 
msg tool_calls only" }], + stream: false, + }); + + expect(resp.status).toBe(200); + + const files = fs.readdirSync(tmpDir); + const fixtureFiles = files.filter((f) => f.endsWith(".json")); + expect(fixtureFiles).toHaveLength(1); + + const fixtureContent = JSON.parse( + fs.readFileSync(path.join(tmpDir, fixtureFiles[0]), "utf-8"), + ) as { + fixtures: Array<{ + response: { + content?: string; + toolCalls?: Array<{ name: string; arguments: string }>; + }; + }>; + }; + expect(fixtureContent.fixtures[0].response.content).toBeUndefined(); + expect(fixtureContent.fixtures[0].response.toolCalls).toBeDefined(); + expect(fixtureContent.fixtures[0].response.toolCalls).toHaveLength(1); + expect(fixtureContent.fixtures[0].response.toolCalls![0].name).toBe("search_docs"); + expect(JSON.parse(fixtureContent.fixtures[0].response.toolCalls![0].arguments)).toEqual({ + query: "aimock", + }); + }); + + it("detects Cohere v2 text-only response (no message-level tool_calls)", async () => { + const { url: upstreamUrl } = await createRawUpstreamWithStatus({ + finish_reason: "COMPLETE", + message: { + role: "assistant", + content: [{ type: "text", text: "Hello from Cohere" }], + }, + }); + + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "aimock-record-")); + recorder = await createServer([], { + port: 0, + record: { providers: { cohere: upstreamUrl }, fixturePath: tmpDir }, + }); + + const resp = await post(`${recorder.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "cohere v2 text only" }], + stream: false, + }); + + expect(resp.status).toBe(200); + + const files = fs.readdirSync(tmpDir); + const fixtureFiles = files.filter((f) => f.endsWith(".json")); + expect(fixtureFiles).toHaveLength(1); + + const fixtureContent = JSON.parse( + fs.readFileSync(path.join(tmpDir, fixtureFiles[0]), "utf-8"), + ) as { + fixtures: Array<{ + response: { + content?: string; + toolCalls?: Array<{ name: string; arguments: string }>; + }; + }>; + }; + 
expect(fixtureContent.fixtures[0].response.content).toBe("Hello from Cohere"); + expect(fixtureContent.fixtures[0].response.toolCalls).toBeUndefined(); + }); + it("unknown format falls back to error response", async () => { const { url: upstreamUrl } = await createRawUpstreamWithStatus({ custom: "data", diff --git a/src/recorder.ts b/src/recorder.ts index fb841c6..229a389 100644 --- a/src/recorder.ts +++ b/src/recorder.ts @@ -36,62 +36,13 @@ const STRIP_HEADERS = new Set([ "accept-encoding", ]); -/** - * Captured upstream response, exposed to the `beforeWriteResponse` hook so - * callers can decide whether to relay it or mutate it (e.g. chaos injection). - */ -export interface ProxyCapturedResponse { - status: number; - contentType: string; - body: Buffer; -} - -export interface ProxyOptions { - /** - * Called after the upstream response has been captured and recorded, but - * before the relay to the client. Contract when the hook returns `true`: - * 1. It wrote its own response body on `res`. - * 2. It journaled the outcome (proxyAndRecord will NOT journal it). - * 3. proxyAndRecord skips its default relay and returns `"handled_by_hook"`. - * - * Returning `false` (or omitting the hook) lets proxyAndRecord relay the - * upstream response normally and leaves journaling to the caller via the - * `"relayed"` outcome. Rejected promises propagate and leave the response - * unwritten. - * - * NOT invoked when the upstream response was streamed progressively to the - * client (SSE) — the bytes are already on the wire and can't be mutated. - * Callers that need to observe the bypass should pass `onHookBypassed`. - */ - beforeWriteResponse?: (response: ProxyCapturedResponse) => boolean | Promise; - /** - * Called when `beforeWriteResponse` was provided but could not be invoked - * because the upstream response was streamed to the client progressively. - * The hook was rolled + wired but the bytes left before it could fire. 
- * Intended for observability (log/metric/journal annotation) — proxyAndRecord - * still returns `"relayed"`. - */ - onHookBypassed?: (reason: "sse_streamed") => void; -} - -/** - * Outcome of a proxyAndRecord call, returned so the caller can decide whether - * to journal, fall through, or stop — without sharing a mutable flag with the - * `beforeWriteResponse` hook. - * - * - `"not_configured"` — no upstream URL for this provider; caller should fall - * through to its next branch (typically strict/404). - * - `"relayed"` — the default code path wrote a response (upstream success or - * synthesized 502 error). Caller should journal the outcome. - * - `"handled_by_hook"` — the hook wrote + journaled its own response. Caller - * should not double-journal. - */ -export type ProxyOutcome = "not_configured" | "relayed" | "handled_by_hook"; - /** * Proxy an unmatched request to the real upstream provider, record the * response as a fixture on disk and in memory, then relay the response * back to the original client. + * + * Returns `true` if the request was proxied (provider configured), + * `false` if no upstream URL is configured for the given provider key. */ export async function proxyAndRecord( req: http.IncomingMessage, @@ -106,17 +57,16 @@ export async function proxyAndRecord( requestTransform?: (req: ChatCompletionRequest) => ChatCompletionRequest; }, rawBody?: string, - options?: ProxyOptions, -): Promise { +): Promise { const record = defaults.record; - if (!record) return "not_configured"; + if (!record) return false; const providers = record.providers; const upstreamUrl = providers[providerKey]; if (!upstreamUrl) { defaults.logger.warn(`No upstream URL configured for provider "${providerKey}" — cannot proxy`); - return "not_configured"; + return false; } const fixturePath = record.fixturePath ?? 
"./fixtures/recorded"; @@ -132,7 +82,7 @@ export async function proxyAndRecord( error: { message: `Invalid upstream URL: ${upstreamUrl}`, type: "proxy_error" }, }), ); - return "relayed"; + return true; } defaults.logger.warn(`NO FIXTURE MATCH — proxying to ${upstreamUrl}${pathname}`); @@ -156,6 +106,7 @@ export async function proxyAndRecord( // Track whether we streamed SSE progressively to the client; if so, // skip the final res.writeHead/res.end relay at the bottom of this fn. let streamedToClient = false; + let clientDisconnected = false; try { const result = await makeUpstreamRequest(target, forwardHeaders, requestBody, res); upstreamStatus = result.status; @@ -163,23 +114,25 @@ export async function proxyAndRecord( upstreamBody = result.body; rawBuffer = result.rawBuffer; streamedToClient = result.streamedToClient; + clientDisconnected = result.clientDisconnected; } catch (err) { const msg = err instanceof Error ? err.message : "Unknown proxy error"; defaults.logger.error(`Proxy request failed: ${msg}`); - res.writeHead(502, { "Content-Type": "application/json" }); - res.end( - JSON.stringify({ - error: { message: `Proxy to upstream failed: ${msg}`, type: "proxy_error" }, - }), - ); - return "relayed"; + if (!res.headersSent) { + res.writeHead(502, { "Content-Type": "application/json" }); + res.end( + JSON.stringify({ + error: { message: `Proxy to upstream failed: ${msg}`, type: "proxy_error" }, + }), + ); + } else { + // SSE headers already sent — gracefully close the connection + res.end(); + } + return true; } - // Detect streaming response and collapse if necessary. - // NOTE: collapse buffers the entire upstream body in memory. Fine for - // current chat-completions traffic (responses are small), but revisit if - // this path ever proxies long-lived or large streams — both the buffer - // here and the hook below receive the full payload. 
+ // Detect streaming response and collapse if necessary const contentType = upstreamHeaders["content-type"]; const ctString = Array.isArray(contentType) ? contentType.join(", ") : (contentType ?? ""); const isBinaryStream = ctString.toLowerCase().includes("application/vnd.amazon.eventstream"); @@ -218,15 +171,20 @@ export async function proxyAndRecord( if (collapsed.content === "" && (!collapsed.toolCalls || collapsed.toolCalls.length === 0)) { defaults.logger.warn("Stream collapse produced empty content — fixture may be incomplete"); } + const reasoningSpread = collapsed.reasoning ? { reasoning: collapsed.reasoning } : {}; if (collapsed.toolCalls && collapsed.toolCalls.length > 0) { if (collapsed.content) { - defaults.logger.warn( - "Collapsed response has both content and toolCalls — preferring toolCalls", - ); + // Both content and toolCalls present — save as ContentWithToolCallsResponse + fixtureResponse = { + content: collapsed.content, + toolCalls: collapsed.toolCalls, + ...reasoningSpread, + }; + } else { + fixtureResponse = { toolCalls: collapsed.toolCalls, ...reasoningSpread }; } - fixtureResponse = { toolCalls: collapsed.toolCalls }; } else { - fixtureResponse = { content: collapsed.content ?? "" }; + fixtureResponse = { content: collapsed.content ?? "", ...reasoningSpread }; } } else { // Non-streaming — try to parse as JSON @@ -240,12 +198,24 @@ export async function proxyAndRecord( let encodingFormat: string | undefined; try { encodingFormat = rawBody ? JSON.parse(rawBody).encoding_format : undefined; - } catch { - /* not JSON */ + } catch (err) { + defaults.logger.debug( + `Could not parse encoding_format from raw body: ${err instanceof Error ? err.message : "unknown error"}`, + ); } fixtureResponse = buildFixtureResponse(parsedResponse, upstreamStatus, encodingFormat); } + // If the client disconnected mid-stream, the collected data is likely + // truncated. Saving a partial fixture is worse than saving none — skip + // fixture persistence entirely. 
+ if (clientDisconnected) { + defaults.logger.warn( + "Client disconnected mid-stream — skipping fixture save to avoid truncated data", + ); + return true; + } + // Build the match criteria from the (optionally transformed) request const matchRequest = defaults.requestTransform ? defaults.requestTransform(request) : request; const fixtureMatch = buildFixtureMatch(matchRequest); @@ -282,10 +252,7 @@ export async function proxyAndRecord( warnings.push("Stream response was truncated — fixture may be incomplete"); } - // Auth headers are forwarded to upstream but excluded from saved fixtures for security. - // NOTE: the persisted fixture is always the real upstream response, even when chaos - // later mutates the relay (e.g. malformed via beforeWriteResponse). Chaos is a live-traffic - // decoration; the recorded artifact must stay truthful so replay sees what upstream said. + // Auth headers are forwarded to upstream but excluded from saved fixtures for security const fileContent: Record = { fixtures: [fixture] }; if (warnings.length > 0) { fileContent._warning = warnings.join("; "); @@ -295,7 +262,11 @@ export async function proxyAndRecord( } catch (err) { const msg = err instanceof Error ? err.message : "Unknown filesystem error"; defaults.logger.error(`Failed to save fixture to disk: ${msg}`); - res.setHeader("X-LLMock-Record-Error", msg); + if (!res.headersSent) { + res.setHeader("X-LLMock-Record-Error", msg); + } else { + defaults.logger.warn(`Cannot set X-LLMock-Record-Error header — headers already sent`); + } } if (writtenToDisk) { @@ -314,35 +285,17 @@ export async function proxyAndRecord( // Relay upstream response to client (skip when SSE was already streamed // progressively by makeUpstreamRequest — headers and body are already on // the wire). - if (streamedToClient) { - // SSE: the hook can't run because the body is already on the wire. 
Surface - // the bypass so the caller (typically the chaos layer) can record it — - // otherwise a configured chaos action silently no-ops on SSE traffic. - if (options?.beforeWriteResponse && options.onHookBypassed) { - options.onHookBypassed("sse_streamed"); - } - } else { - // Give the caller a chance to mutate or replace the response before relay. - // Used by the chaos layer to turn a successful proxy into a malformed body. - // `body` is the raw upstream bytes so binary payloads survive round-tripping. - if (options?.beforeWriteResponse) { - const handled = await options.beforeWriteResponse({ - status: upstreamStatus, - contentType: ctString, - body: rawBuffer, - }); - if (handled) return "handled_by_hook"; - } - + if (!streamedToClient) { const relayHeaders: Record = {}; if (ctString) { relayHeaders["Content-Type"] = ctString; } res.writeHead(upstreamStatus, relayHeaders); - res.end(isBinaryStream ? rawBuffer : upstreamBody); + const isAudioRelay = ctString.toLowerCase().startsWith("audio/"); + res.end(isBinaryStream || isAudioRelay ? rawBuffer : upstreamBody); } - return "relayed"; + return true; } // --------------------------------------------------------------------------- @@ -360,6 +313,7 @@ function makeUpstreamRequest( body: string; rawBuffer: Buffer; streamedToClient: boolean; + clientDisconnected: boolean; }> { return new Promise((resolve, reject) => { const transport = target.protocol === "https:" ? https : http; @@ -388,6 +342,7 @@ function makeUpstreamRequest( const ctStr = Array.isArray(ct) ? ct.join(", ") : (ct ?? ""); const isSSE = ctStr.toLowerCase().includes("text/event-stream"); let streamedToClient = false; + let clientDisconnected = false; if (isSSE && clientRes && !clientRes.headersSent) { const relayHeaders: Record = {}; if (ctStr) relayHeaders["Content-Type"] = ctStr; @@ -396,22 +351,44 @@ function makeUpstreamRequest( // before the first data chunk arrives. 
if (typeof clientRes.flushHeaders === "function") clientRes.flushHeaders(); streamedToClient = true; + // Stop relaying if the client disconnects mid-stream + clientRes.on("close", () => { + clientDisconnected = true; + req.destroy(); + }); } const chunks: Buffer[] = []; res.on("data", (chunk: Buffer) => { chunks.push(chunk); - if (streamedToClient) clientRes!.write(chunk); + if ( + streamedToClient && + clientRes && + !clientDisconnected && + !clientRes.destroyed && + !clientRes.writableEnded + ) { + clientRes.write(chunk); + } }); res.on("error", reject); res.on("end", () => { const rawBuffer = Buffer.concat(chunks); - if (streamedToClient) clientRes!.end(); + if ( + streamedToClient && + clientRes && + !clientDisconnected && + !clientRes.destroyed && + !clientRes.writableEnded + ) { + clientRes.end(); + } resolve({ status: res.statusCode ?? 500, headers: res.headers, body: rawBuffer.toString(), rawBuffer, streamedToClient, + clientDisconnected, }); }); }, @@ -448,8 +425,13 @@ function buildFixtureResponse( const obj = parsed as Record; - // Error response - if (obj.error) { + // Error response — only match the actual { error: { message: "..." } } shape + // used by OpenAI/Anthropic/etc., not arbitrary truthy `.error` fields. 
+ if ( + typeof obj.error === "object" && + obj.error !== null && + typeof (obj.error as Record).message === "string" + ) { const err = obj.error as Record; return { error: { @@ -468,13 +450,10 @@ function buildFixtureResponse( return { embedding: first.embedding as number[] }; } if (typeof first.embedding === "string" && encodingFormat === "base64") { - try { - const buf = Buffer.from(first.embedding, "base64"); - const floats = new Float32Array(buf.buffer, buf.byteOffset, buf.byteLength / 4); - return { embedding: Array.from(floats) }; - } catch { - // Corrupted base64 or non-float32 data — fall through to error - } + const buf = Buffer.from(first.embedding, "base64"); + const aligned = new Uint8Array(buf).buffer; // Always offset 0 + const floats = new Float32Array(aligned, 0, buf.byteLength / 4); + return { embedding: Array.from(floats) }; } // OpenAI image generation: { created, data: [{ url, b64_json, revised_prompt }] } if (first.url || first.b64_json) { @@ -519,10 +498,19 @@ function buildFixtureResponse( } // OpenAI video generation: { id, status, ... } + // Guard against false positives: many API responses have `id` + `status` fields + // (e.g. chat completions, Anthropic messages). Reject if the response has fields + // that indicate a known non-video format. 
if ( typeof obj.id === "string" && typeof obj.status === "string" && - (obj.status === "completed" || obj.status === "in_progress" || obj.status === "failed") + (obj.status === "completed" || obj.status === "in_progress" || obj.status === "failed") && + !("choices" in obj) && + !("content" in obj) && + !("candidates" in obj) && + !("message" in obj) && + !("data" in obj) && + !("object" in obj) ) { if (obj.status === "completed" && obj.url) { return { @@ -551,42 +539,84 @@ function buildFixtureResponse( const choice = obj.choices[0] as Record; const message = choice.message as Record | undefined; if (message) { - // Tool calls - if (Array.isArray(message.tool_calls) && message.tool_calls.length > 0) { + const hasToolCalls = Array.isArray(message.tool_calls) && message.tool_calls.length > 0; + const hasContent = typeof message.content === "string" && message.content.length > 0; + + const openaiReasoning = + typeof message.reasoning_content === "string" && message.reasoning_content.length > 0 + ? message.reasoning_content + : undefined; + + if (hasToolCalls) { const toolCalls: ToolCall[] = (message.tool_calls as Array>).map( (tc) => { const fn = tc.function as Record; return { name: String(fn.name), arguments: String(fn.arguments), + ...(tc.id ? { id: String(tc.id) } : {}), }; }, ); - return { toolCalls }; + if (hasContent) { + return { + content: message.content as string, + toolCalls, + ...(openaiReasoning ? { reasoning: openaiReasoning } : {}), + }; + } + return { toolCalls, ...(openaiReasoning ? { reasoning: openaiReasoning } : {}) }; } - // Text content - if (typeof message.content === "string") { - return { content: message.content }; + // Text content only + if (hasContent) { + return { + content: message.content as string, + ...(openaiReasoning ? { reasoning: openaiReasoning } : {}), + }; } + // Recognized OpenAI shape but empty content (e.g. content filtering, zero max_tokens) + return { content: "", ...(openaiReasoning ? 
{ reasoning: openaiReasoning } : {}) }; } } // Anthropic: { content: [{ type: "text", text: "..." }] } or tool_use if (Array.isArray(obj.content) && obj.content.length > 0) { const blocks = obj.content as Array>; - // Check for tool_use blocks first const toolUseBlocks = blocks.filter((b) => b.type === "tool_use"); - if (toolUseBlocks.length > 0) { + const textBlocks = blocks.filter((b) => b.type === "text" && typeof b.text === "string"); + const thinkingBlocks = blocks.filter((b) => b.type === "thinking"); + const hasToolCalls = toolUseBlocks.length > 0; + const joinedText = textBlocks.map((b) => String(b.text ?? "")).join(""); + const hasContent = joinedText.length > 0; + const anthropicReasoning = + thinkingBlocks.length > 0 + ? thinkingBlocks.map((b) => String(b.thinking ?? "")).join("") + : undefined; + + if (hasToolCalls) { const toolCalls: ToolCall[] = toolUseBlocks.map((b) => ({ name: String(b.name), arguments: typeof b.input === "string" ? b.input : JSON.stringify(b.input), + ...(b.id ? { id: String(b.id) } : {}), })); - return { toolCalls }; + if (hasContent) { + return { + content: joinedText, + toolCalls, + ...(anthropicReasoning ? { reasoning: anthropicReasoning } : {}), + }; + } + return { toolCalls, ...(anthropicReasoning ? { reasoning: anthropicReasoning } : {}) }; + } + if (hasContent) { + return { + content: joinedText, + ...(anthropicReasoning ? 
{ reasoning: anthropicReasoning } : {}), + }; } - // Text blocks - const textBlock = blocks.find((b) => b.type === "text"); - if (textBlock && typeof textBlock.text === "string") { - return { content: textBlock.text }; + // Thinking-only response (no text, no tool calls) + if (anthropicReasoning) { + return { content: "", reasoning: anthropicReasoning }; } } @@ -596,9 +626,18 @@ function buildFixtureResponse( const content = candidate.content as Record | undefined; if (content && Array.isArray(content.parts)) { const parts = content.parts as Array>; - // Tool calls (functionCall) const fnCallParts = parts.filter((p) => p.functionCall); - if (fnCallParts.length > 0) { + const textParts = parts.filter((p) => typeof p.text === "string" && !p.thought); + const thoughtParts = parts.filter((p) => p.thought === true && typeof p.text === "string"); + const hasToolCalls = fnCallParts.length > 0; + const joinedText = textParts.map((p) => String(p.text ?? "")).join(""); + const hasContent = joinedText.length > 0; + const geminiReasoning = + thoughtParts.length > 0 + ? thoughtParts.map((p) => String(p.text ?? "")).join("") + : undefined; + + if (hasToolCalls) { const toolCalls: ToolCall[] = fnCallParts.map((p) => { const fc = p.functionCall as Record; return { @@ -606,13 +645,23 @@ function buildFixtureResponse( arguments: typeof fc.args === "string" ? fc.args : JSON.stringify(fc.args), }; }); - return { toolCalls }; + if (hasContent) { + return { + content: joinedText, + toolCalls, + ...(geminiReasoning ? { reasoning: geminiReasoning } : {}), + }; + } + return { toolCalls, ...(geminiReasoning ? { reasoning: geminiReasoning } : {}) }; } - // Text - const textPart = parts.find((p) => typeof p.text === "string"); - if (textPart && typeof textPart.text === "string") { - return { content: textPart.text }; + if (hasContent) { + return { + content: joinedText, + ...(geminiReasoning ? 
{ reasoning: geminiReasoning } : {}), + }; } + // Recognized Gemini shape but empty content + return { content: "", ...(geminiReasoning ? { reasoning: geminiReasoning } : {}) }; } } @@ -623,28 +672,122 @@ function buildFixtureResponse( if (msg && Array.isArray(msg.content)) { const blocks = msg.content as Array>; const toolUseBlocks = blocks.filter((b) => b.toolUse); - if (toolUseBlocks.length > 0) { + const textBlocks = blocks.filter((b) => typeof b.text === "string"); + const reasoningBlocks = blocks.filter((b) => b.reasoningContent); + const hasToolCalls = toolUseBlocks.length > 0; + const joinedText = textBlocks.map((b) => String(b.text ?? "")).join(""); + const hasContent = joinedText.length > 0; + const bedrockReasoning = + reasoningBlocks.length > 0 + ? reasoningBlocks + .map((b) => { + const rc = b.reasoningContent as Record; + const rt = rc?.reasoningText as Record | undefined; + return String(rt?.text ?? ""); + }) + .join("") + : undefined; + + if (hasToolCalls) { const toolCalls: ToolCall[] = toolUseBlocks.map((b) => { const tu = b.toolUse as Record; return { name: String(tu.name ?? ""), arguments: typeof tu.input === "string" ? tu.input : JSON.stringify(tu.input), + ...(tu.toolUseId ? { id: String(tu.toolUseId) } : {}), }; }); - return { toolCalls }; + if (hasContent) { + return { + content: joinedText, + toolCalls, + ...(bedrockReasoning ? { reasoning: bedrockReasoning } : {}), + }; + } + return { toolCalls, ...(bedrockReasoning ? { reasoning: bedrockReasoning } : {}) }; } - const textBlock = blocks.find((b) => typeof b.text === "string"); - if (textBlock && typeof textBlock.text === "string") { - return { content: textBlock.text }; + if (hasContent) { + return { + content: joinedText, + ...(bedrockReasoning ? { reasoning: bedrockReasoning } : {}), + }; } + // Recognized Bedrock Converse shape but empty content + return { content: "", ...(bedrockReasoning ? 
{ reasoning: bedrockReasoning } : {}) }; + } + } + + // Cohere v2 chat: { finish_reason: "...", message: { content: [{ type: "text", text: "..." }] } } + // Must come before Ollama since both have `message`, but Cohere has `finish_reason` at top level + // (not nested in `choices`) and `message.content` as an array of typed objects. + if ( + typeof obj.finish_reason === "string" && + obj.message && + typeof obj.message === "object" && + Array.isArray((obj.message as Record).content) + ) { + const msg = obj.message as Record; + const contentBlocks = msg.content as Array>; + const textBlock = contentBlocks.find((b) => b.type === "text" && typeof b.text === "string"); + const hasContent = textBlock && typeof textBlock.text === "string" && textBlock.text.length > 0; + const toolCallBlocks = contentBlocks.filter((b) => b.type === "tool_call"); + + // Also check message-level tool_calls (Cohere v2 puts tool calls here, not in content blocks) + const msgToolCalls = Array.isArray(msg.tool_calls) + ? (msg.tool_calls as Array>) + : []; + + if (toolCallBlocks.length > 0) { + const toolCalls: ToolCall[] = toolCallBlocks.map((b) => ({ + name: String(b.name ?? (b.function as Record)?.name ?? ""), + arguments: + typeof b.parameters === "string" + ? b.parameters + : typeof b.parameters === "object" + ? JSON.stringify(b.parameters) + : typeof (b.function as Record)?.arguments === "string" + ? String((b.function as Record).arguments) + : JSON.stringify((b.function as Record)?.arguments), + ...(b.id ? { id: String(b.id) } : {}), + })); + if (hasContent) { + return { content: textBlock.text as string, toolCalls }; + } + return { toolCalls }; + } + if (msgToolCalls.length > 0) { + const toolCalls: ToolCall[] = msgToolCalls.map((tc) => { + const fn = tc.function as Record | undefined; + return { + name: String(tc.name ?? fn?.name ?? ""), + arguments: + typeof tc.parameters === "string" + ? tc.parameters + : typeof tc.parameters === "object" + ? 
JSON.stringify(tc.parameters) + : typeof fn?.arguments === "string" + ? String(fn.arguments) + : JSON.stringify(fn?.arguments), + ...(tc.id ? { id: String(tc.id) } : {}), + }; + }); + if (hasContent) { + return { content: textBlock.text as string, toolCalls }; + } + return { toolCalls }; + } + if (hasContent) { + return { content: textBlock.text as string }; } } // Ollama: { message: { content: "...", tool_calls: [...] } } if (obj.message && typeof obj.message === "object") { const msg = obj.message as Record; - // Tool calls (check before content — Ollama sends content: "" alongside tool_calls) - if (Array.isArray(msg.tool_calls) && msg.tool_calls.length > 0) { + const hasOllamaToolCalls = Array.isArray(msg.tool_calls) && msg.tool_calls.length > 0; + const hasOllamaContent = typeof msg.content === "string" && msg.content.length > 0; + + if (hasOllamaToolCalls) { const toolCalls: ToolCall[] = (msg.tool_calls as Array>) .filter((tc) => tc.function != null) .map((tc) => { @@ -655,10 +798,13 @@ function buildFixtureResponse( typeof fn.arguments === "string" ? 
fn.arguments : JSON.stringify(fn.arguments), }; }); + if (hasOllamaContent) { + return { content: msg.content as string, toolCalls }; + } return { toolCalls }; } - if (typeof msg.content === "string" && msg.content.length > 0) { - return { content: msg.content }; + if (hasOllamaContent) { + return { content: msg.content as string }; } // Ollama message with content array (like Cohere) if (Array.isArray(msg.content) && msg.content.length > 0) { @@ -669,6 +815,11 @@ function buildFixtureResponse( } } + // Ollama /api/generate: { response: "...", done: true/false } + if (typeof obj.response === "string" && "done" in obj) { + return { content: obj.response }; + } + // Fallback: unknown format — save as error return { error: { From b0d9a49d9addc6d3dc6be23d8d6e649153c0be55 Mon Sep 17 00:00:00 2001 From: Jordan Ritter Date: Thu, 23 Apr 2026 16:23:16 -0700 Subject: [PATCH 3/6] =?UTF-8?q?fix:=20provider=20handlers=20=E2=80=94=20ov?= =?UTF-8?q?errides,=20parity,=20stream=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bedrock/Converse: ContentWithToolCallsResponse, ResponseOverrides in all builders + streaming, Converse-wrapped stream events, text_delta type, error envelope, webSearches warnings. Cohere v2: reasoning, webSearches, response_format, tool_calls, full overrides. Ollama: content+toolCalls, stream default, validation. Gemini: callCounter collision fix, thought filtering. All: journal source field, Azure embedding routing. 
--- src/__tests__/bedrock-stream.test.ts | 75 ++- src/__tests__/bedrock.test.ts | 327 ++++++++++++- src/__tests__/cohere.test.ts | 322 ++++++++++++- src/__tests__/ollama.test.ts | 16 +- src/__tests__/provider-compat.test.ts | 10 +- src/__tests__/reasoning-all-providers.test.ts | 65 ++- src/bedrock-converse.ts | 228 ++++++++- src/bedrock.ts | 366 +++++++++++++-- src/cohere.ts | 434 ++++++++++++++++-- src/embeddings.ts | 41 +- src/gemini.ts | 31 +- src/images.ts | 25 +- src/messages.ts | 9 +- src/ollama.ts | 183 +++++++- src/responses.ts | 23 +- src/speech.ts | 25 +- src/transcription.ts | 16 +- src/video.ts | 97 +++- 18 files changed, 2069 insertions(+), 224 deletions(-) diff --git a/src/__tests__/bedrock-stream.test.ts b/src/__tests__/bedrock-stream.test.ts index 349dea0..8a1a3bd 100644 --- a/src/__tests__/bedrock-stream.test.ts +++ b/src/__tests__/bedrock-stream.test.ts @@ -263,17 +263,24 @@ describe("POST /model/{modelId}/invoke-with-response-stream", () => { // messageStart expect(frames[0].eventType).toBe("messageStart"); - expect(frames[0].payload).toEqual({ role: "assistant" }); + expect(frames[0].payload).toEqual({ messageStart: { role: "assistant" } }); // contentBlockStart expect(frames[1].eventType).toBe("contentBlockStart"); - expect(frames[1].payload).toEqual({ contentBlockIndex: 0, start: {} }); + expect(frames[1].payload).toEqual({ + contentBlockIndex: 0, + contentBlockStart: { contentBlockIndex: 0, start: { type: "text" } }, + }); // Content delta(s) — collect text const deltas = frames.filter((f) => f.eventType === "contentBlockDelta"); expect(deltas.length).toBeGreaterThanOrEqual(1); const fullText = deltas - .map((f) => (f.payload as { delta: { text: string } }).delta.text) + .map( + (f) => + (f.payload as { contentBlockDelta: { delta: { text: string } } }).contentBlockDelta.delta + .text, + ) .join(""); expect(fullText).toBe("Hi there!"); @@ -301,23 +308,30 @@ describe("POST /model/{modelId}/invoke-with-response-stream", () => { // messageStart 
expect(frames[0].eventType).toBe("messageStart"); - expect(frames[0].payload).toEqual({ role: "assistant" }); + expect(frames[0].payload).toEqual({ messageStart: { role: "assistant" } }); // contentBlockStart with toolUse expect(frames[1].eventType).toBe("contentBlockStart"); const startPayload = frames[1].payload as { contentBlockIndex: number; - start: { toolUse: { toolUseId: string; name: string } }; + contentBlockStart: { + contentBlockIndex: number; + start: { toolUse: { toolUseId: string; name: string } }; + }; }; expect(startPayload.contentBlockIndex).toBe(0); - expect(startPayload.start.toolUse.name).toBe("get_weather"); - expect(startPayload.start.toolUse.toolUseId).toBeDefined(); + expect(startPayload.contentBlockStart.start.toolUse.name).toBe("get_weather"); + expect(startPayload.contentBlockStart.start.toolUse.toolUseId).toBeDefined(); - // contentBlockDelta(s) with input_json_delta + // contentBlockDelta(s) with toolUse input const deltas = frames.filter((f) => f.eventType === "contentBlockDelta"); expect(deltas.length).toBeGreaterThanOrEqual(1); const fullJson = deltas - .map((f) => (f.payload as { delta: { inputJSON: string } }).delta.inputJSON) + .map( + (f) => + (f.payload as { contentBlockDelta: { delta: { toolUse: { input: string } } } }) + .contentBlockDelta.delta.toolUse.input, + ) .join(""); expect(JSON.parse(fullJson)).toEqual({ city: "SF" }); @@ -460,18 +474,24 @@ describe("POST /model/{modelId}/invoke-with-response-stream (multiple tool calls // First tool at contentBlockIndex 0 const start0 = blockStarts[0].payload as { contentBlockIndex: number; - start: { toolUse: { name: string } }; + contentBlockStart: { + contentBlockIndex: number; + start: { toolUse: { name: string } }; + }; }; expect(start0.contentBlockIndex).toBe(0); - expect(start0.start.toolUse.name).toBe("get_weather"); + expect(start0.contentBlockStart.start.toolUse.name).toBe("get_weather"); // Second tool at contentBlockIndex 1 const start1 = blockStarts[1].payload as { 
contentBlockIndex: number; - start: { toolUse: { name: string } }; + contentBlockStart: { + contentBlockIndex: number; + start: { toolUse: { name: string } }; + }; }; expect(start1.contentBlockIndex).toBe(1); - expect(start1.start.toolUse.name).toBe("get_time"); + expect(start1.contentBlockStart.start.toolUse.name).toBe("get_time"); // contentBlockStop should also have correct indices const blockStops = frames.filter((f) => f.eventType === "contentBlockStop"); @@ -623,13 +643,17 @@ describe("POST /model/{modelId}/converse-stream", () => { // Verify event sequence expect(frames[0].eventType).toBe("messageStart"); - expect(frames[0].payload).toEqual({ role: "assistant" }); + expect(frames[0].payload).toEqual({ messageStart: { role: "assistant" } }); expect(frames[1].eventType).toBe("contentBlockStart"); const deltas = frames.filter((f) => f.eventType === "contentBlockDelta"); const fullText = deltas - .map((f) => (f.payload as { delta: { text: string } }).delta.text) + .map( + (f) => + (f.payload as { contentBlockDelta: { delta: { text: string } } }).contentBlockDelta.delta + .text, + ) .join(""); expect(fullText).toBe("Hi there!"); @@ -651,13 +675,20 @@ describe("POST /model/{modelId}/converse-stream", () => { const startFrame = frames.find((f) => f.eventType === "contentBlockStart"); const startPayload = startFrame!.payload as { contentBlockIndex: number; - start: { toolUse: { toolUseId: string; name: string } }; + contentBlockStart: { + contentBlockIndex: number; + start: { toolUse: { toolUseId: string; name: string } }; + }; }; - expect(startPayload.start.toolUse.name).toBe("get_weather"); + expect(startPayload.contentBlockStart.start.toolUse.name).toBe("get_weather"); const deltas = frames.filter((f) => f.eventType === "contentBlockDelta"); const fullJson = deltas - .map((f) => (f.payload as { delta: { inputJSON: string } }).delta.inputJSON) + .map( + (f) => + (f.payload as { contentBlockDelta: { delta: { toolUse: { input: string } } } }) + 
.contentBlockDelta.delta.toolUse.input, + ) .join(""); expect(JSON.parse(fullJson)).toEqual({ city: "SF" }); @@ -994,12 +1025,14 @@ describe("POST /model/{modelId}/invoke-with-response-stream (malformed tool args expect(res.status).toBe(200); const frames = parseFrames(res.body); - // Find contentBlockDelta frames with inputJSON + // Find contentBlockDelta frames with toolUse input const deltas = frames.filter((f) => f.eventType === "contentBlockDelta"); const fullJson = deltas .map((f) => { - const payload = f.payload as { delta: { inputJSON?: string } }; - return payload.delta.inputJSON ?? ""; + const payload = f.payload as { + contentBlockDelta: { delta: { toolUse?: { input: string } } }; + }; + return payload.contentBlockDelta.delta.toolUse?.input ?? ""; }) .join(""); // Malformed arguments should fall back to "{}" diff --git a/src/__tests__/bedrock.test.ts b/src/__tests__/bedrock.test.ts index 5fc47d9..f38ab40 100644 --- a/src/__tests__/bedrock.test.ts +++ b/src/__tests__/bedrock.test.ts @@ -1445,7 +1445,6 @@ describe("POST /model/{modelId}/invoke (error fixture no error type)", () => { // --------------------------------------------------------------------------- import { buildBedrockStreamTextEvents, buildBedrockStreamToolCallEvents } from "../bedrock.js"; -import { Logger } from "../logger.js"; describe("buildBedrockStreamTextEvents", () => { it("creates correct event sequence for empty content", () => { @@ -1462,9 +1461,18 @@ describe("buildBedrockStreamTextEvents", () => { const events = buildBedrockStreamTextEvents("ABCDEF", 2); const deltas = events.filter((e) => e.eventType === "contentBlockDelta"); expect(deltas).toHaveLength(3); - expect((deltas[0].payload as { delta: { text: string } }).delta.text).toBe("AB"); - expect((deltas[1].payload as { delta: { text: string } }).delta.text).toBe("CD"); - expect((deltas[2].payload as { delta: { text: string } }).delta.text).toBe("EF"); + expect( + (deltas[0].payload as { contentBlockDelta: { delta: { text: 
string } } }).contentBlockDelta + .delta.text, + ).toBe("AB"); + expect( + (deltas[1].payload as { contentBlockDelta: { delta: { text: string } } }).contentBlockDelta + .delta.text, + ).toBe("CD"); + expect( + (deltas[2].payload as { contentBlockDelta: { delta: { text: string } } }).contentBlockDelta + .delta.text, + ).toBe("EF"); }); }); @@ -1479,7 +1487,11 @@ describe("buildBedrockStreamToolCallEvents", () => { ); const deltas = events.filter((e) => e.eventType === "contentBlockDelta"); const fullJson = deltas - .map((e) => (e.payload as { delta: { inputJSON: string } }).delta.inputJSON) + .map( + (e) => + (e.payload as { contentBlockDelta: { delta: { toolUse: { input: string } } } }) + .contentBlockDelta.delta.toolUse.input, + ) .join(""); expect(fullJson).toBe("{}"); }); @@ -1492,9 +1504,9 @@ describe("buildBedrockStreamToolCallEvents", () => { ); const startEvent = events.find((e) => e.eventType === "contentBlockStart"); const payload = startEvent!.payload as { - start: { toolUse: { toolUseId: string } }; + contentBlockStart: { start: { toolUse: { toolUseId: string } } }; }; - expect(payload.start.toolUse.toolUseId).toMatch(/^toolu_/); + expect(payload.contentBlockStart.start.toolUse.toolUseId).toMatch(/^toolu_/); }); it("uses provided tool id", () => { @@ -1505,16 +1517,20 @@ describe("buildBedrockStreamToolCallEvents", () => { ); const startEvent = events.find((e) => e.eventType === "contentBlockStart"); const payload = startEvent!.payload as { - start: { toolUse: { toolUseId: string } }; + contentBlockStart: { start: { toolUse: { toolUseId: string } } }; }; - expect(payload.start.toolUse.toolUseId).toBe("custom_id"); + expect(payload.contentBlockStart.start.toolUse.toolUseId).toBe("custom_id"); }); it("uses '{}' when arguments is empty string", () => { const events = buildBedrockStreamToolCallEvents([{ name: "fn", arguments: "" }], 100, logger); const deltas = events.filter((e) => e.eventType === "contentBlockDelta"); const fullJson = deltas - .map((e) => 
(e.payload as { delta: { inputJSON: string } }).delta.inputJSON) + .map( + (e) => + (e.payload as { contentBlockDelta: { delta: { toolUse: { input: string } } } }) + .contentBlockDelta.delta.toolUse.input, + ) .join(""); expect(fullJson).toBe("{}"); }); @@ -1557,3 +1573,294 @@ describe("POST /model/{modelId}/invoke (strict mode)", () => { expect(body.content[0].text).toBe("Hi there!"); }); }); + +// ─── Bedrock ResponseOverrides ───────────────────────────────────────────── + +describe("Bedrock ResponseOverrides", () => { + it("applies id/model/finishReason overrides on invoke text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "ov" }, + response: { + content: "Overridden.", + id: "msg-custom-123", + model: "custom-model", + finishReason: "length", + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + `${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke`, + { + anthropic_version: "bedrock-2023-05-31", + max_tokens: 512, + messages: [{ role: "user", content: "ov" }], + }, + ); + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.id).toBe("msg-custom-123"); + expect(body.model).toBe("custom-model"); + expect(body.stop_reason).toBe("max_tokens"); + }); + + it("applies usage overrides on invoke text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "ov-usage" }, + response: { + content: "Usage test.", + usage: { input_tokens: 15, output_tokens: 25 }, + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + `${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke`, + { + anthropic_version: "bedrock-2023-05-31", + max_tokens: 512, + messages: [{ role: "user", content: "ov-usage" }], + }, + ); + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.usage.input_tokens).toBe(15); + expect(body.usage.output_tokens).toBe(25); + }); + + it("applies 
overrides on invoke tool call response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "ov-tool" }, + response: { + toolCalls: [{ name: "fn", arguments: '{"x":1}' }], + id: "tc-ov-id", + finishReason: "stop", + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + `${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke`, + { + anthropic_version: "bedrock-2023-05-31", + max_tokens: 512, + messages: [{ role: "user", content: "ov-tool" }], + }, + ); + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.id).toBe("tc-ov-id"); + expect(body.stop_reason).toBe("end_turn"); + }); + + it("applies overrides on invoke content+toolCalls response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "ov-cwtc" }, + response: { + content: "Here is the result.", + toolCalls: [{ name: "fn", arguments: '{"a":1}' }], + id: "cwtc-ov-id", + finishReason: "tool_calls", + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + `${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke`, + { + anthropic_version: "bedrock-2023-05-31", + max_tokens: 512, + messages: [{ role: "user", content: "ov-cwtc" }], + }, + ); + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.id).toBe("cwtc-ov-id"); + expect(body.stop_reason).toBe("tool_use"); + }); +}); + +// ─── Bedrock webSearches warning ─────────────────────────────────────────── + +describe("Bedrock webSearches warning", () => { + it("logs warning for text response with webSearches on invoke", async () => { + const warnings: string[] = []; + const logger = new Logger("silent"); + logger.warn = (msg: string) => { + warnings.push(msg); + }; + + const fixture: Fixture = { + match: { userMessage: "web" }, + response: { content: "Result.", webSearches: ["test"] }, + }; + const journal = new Journal(); + const req = { + method: undefined, + url: undefined, + 
headers: {}, + } as unknown as http.IncomingMessage; + const res = { + _written: "", + writableEnded: false, + statusCode: 0, + writeHead(s: number) { + this.statusCode = s; + }, + setHeader() {}, + write(d: string) { + this._written += d; + return true; + }, + end(d?: string) { + if (d) this._written += d; + this.writableEnded = true; + }, + destroy() { + this.writableEnded = true; + }, + } as unknown as http.ServerResponse; + + await handleBedrock( + req, + res, + JSON.stringify({ + anthropic_version: "bedrock-2023-05-31", + max_tokens: 512, + messages: [{ role: "user", content: "web" }], + }), + "anthropic.claude-3-5-sonnet-20241022-v2:0", + [fixture], + journal, + { + latency: 0, + chunkSize: 100, + logger, + } as HandlerDefaults, + () => {}, + ); + + expect(warnings.some((w) => w.includes("webSearches") && w.includes("Bedrock"))).toBe(true); + }); +}); + +// ─── Bedrock Converse ResponseOverrides ──────────────────────────────────── + +describe("Bedrock Converse ResponseOverrides", () => { + it("applies finishReason override on converse text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "conv-ov" }, + response: { + content: "Overridden.", + finishReason: "length", + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + `${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/converse`, + { + messages: [{ role: "user", content: [{ text: "conv-ov" }] }], + }, + ); + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.stopReason).toBe("max_tokens"); + }); + + it("applies usage override on converse text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "conv-usage" }, + response: { + content: "Usage test.", + usage: { input_tokens: 5, output_tokens: 10 }, + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + `${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/converse`, + { + 
messages: [{ role: "user", content: [{ text: "conv-usage" }] }], + }, + ); + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.usage.inputTokens).toBe(5); + expect(body.usage.outputTokens).toBe(10); + expect(body.usage.totalTokens).toBe(15); + }); + + it("applies overrides on converse tool call response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "conv-tool" }, + response: { + toolCalls: [{ name: "fn", arguments: '{"a":1}' }], + finishReason: "stop", + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + `${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/converse`, + { + messages: [{ role: "user", content: [{ text: "conv-tool" }] }], + }, + ); + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.stopReason).toBe("end_turn"); + }); + + it("applies overrides on converse content+toolCalls response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "conv-cwtc" }, + response: { + content: "Some text.", + toolCalls: [{ name: "fn", arguments: '{"a":1}' }], + finishReason: "tool_calls", + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + `${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/converse`, + { + messages: [{ role: "user", content: [{ text: "conv-cwtc" }] }], + }, + ); + expect(res.status).toBe(200); + const body = JSON.parse(res.body); + expect(body.stopReason).toBe("tool_use"); + }); +}); + +// ─── Bedrock Converse webSearches warning ────────────────────────────────── + +describe("Bedrock Converse webSearches warning", () => { + it("logs warning for text response with webSearches on converse", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "conv-web" }, + response: { content: "Result.", webSearches: ["test"] }, + }, + ]; + instance = await createServer(fixtures); + const res = await post( + 
`${instance.url}/model/anthropic.claude-3-5-sonnet-20241022-v2:0/converse`, + { + messages: [{ role: "user", content: [{ text: "conv-web" }] }], + }, + ); + // Should still succeed — webSearches is just ignored with a warning + expect(res.status).toBe(200); + }); +}); diff --git a/src/__tests__/cohere.test.ts b/src/__tests__/cohere.test.ts index e954dd5..a59549c 100644 --- a/src/__tests__/cohere.test.ts +++ b/src/__tests__/cohere.test.ts @@ -829,8 +829,8 @@ describe("POST /v2/chat (malformed tool call arguments)", () => { const body = JSON.parse(res.body); expect(body.message.tool_calls).toHaveLength(1); expect(body.message.tool_calls[0].function.name).toBe("fn"); - // Cohere passes through the arguments string as-is (logs warning) - expect(body.message.tool_calls[0].function.arguments).toBe("NOT VALID JSON"); + // Malformed JSON falls back to "{}" (logs warning) + expect(body.message.tool_calls[0].function.arguments).toBe("{}"); }); }); @@ -1416,3 +1416,321 @@ describe("handleCohere (direct handler call, method/url fallbacks)", () => { expect(entry!.response.status).toBe(500); }); }); + +// ─── Cohere reasoning support ────────────────────────────────────────────── + +describe("Cohere reasoning support", () => { + it("includes reasoning as text block in non-streaming text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "think" }, + response: { content: "The answer is 42.", reasoning: "Let me reason step by step..." 
}, + }, + ]; + instance = await createServer(fixtures); + const res = await post(`${instance.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "think" }], + }); + expect(res.status).toBe(200); + const json = JSON.parse(res.body); + expect(json.message.content).toHaveLength(2); + expect(json.message.content[0].text).toBe("Let me reason step by step..."); + expect(json.message.content[1].text).toBe("The answer is 42."); + expect(json.finish_reason).toBe("COMPLETE"); + }); + + it("includes reasoning blocks in streaming text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "think-stream" }, + response: { content: "Result.", reasoning: "Thinking..." }, + }, + ]; + instance = await createServer(fixtures); + const res = await post(`${instance.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "think-stream" }], + stream: true, + }); + expect(res.status).toBe(200); + const events = parseSSEEvents(res.body); + + // Should have content-start/delta/end for reasoning (index 0) then content (index 1) + const contentDeltas = events.filter((e) => e.event === "content-delta"); + expect(contentDeltas.length).toBeGreaterThanOrEqual(2); + // First content delta should be the reasoning text + const firstDelta = contentDeltas[0].data as { + delta: { message: { content: { text: string } } }; + }; + expect(firstDelta.delta.message.content.text).toBe("Thinking..."); + }); + + it("includes reasoning in content+toolCalls non-streaming response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "think-tool" }, + response: { + content: "Let me check.", + toolCalls: [{ name: "lookup", arguments: '{"q":"test"}' }], + reasoning: "Need to look this up.", + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post(`${instance.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "think-tool" }], + }); + 
expect(res.status).toBe(200); + const json = JSON.parse(res.body); + // reasoning block + text block + expect(json.message.content.length).toBeGreaterThanOrEqual(2); + expect(json.message.content[0].text).toBe("Need to look this up."); + expect(json.message.content[1].text).toBe("Let me check."); + expect(json.message.tool_calls.length).toBe(1); + expect(json.finish_reason).toBe("TOOL_CALL"); + }); +}); + +// ─── Cohere webSearches warning ──────────────────────────────────────────── + +describe("Cohere webSearches warning", () => { + it("logs warning when text response has webSearches", async () => { + const warnings: string[] = []; + const logger = new Logger("silent"); + logger.warn = (msg: string) => { + warnings.push(msg); + }; + + const fixture: Fixture = { + match: { userMessage: "web" }, + response: { content: "Result.", webSearches: ["test"] }, + }; + const journal = new Journal(); + const req = createMockReq(); + const res = createMockRes(); + + await handleCohere( + req, + res, + JSON.stringify({ model: "cmd-r", messages: [{ role: "user", content: "web" }] }), + [fixture], + journal, + createDefaults({ logger }), + () => {}, + ); + + expect(warnings.some((w) => w.includes("webSearches") && w.includes("Cohere"))).toBe(true); + }); + + it("logs warning when content+toolCalls response has webSearches", async () => { + const warnings: string[] = []; + const logger = new Logger("silent"); + logger.warn = (msg: string) => { + warnings.push(msg); + }; + + const fixture: Fixture = { + match: { userMessage: "web-tool" }, + response: { + content: "Here.", + toolCalls: [{ name: "fn", arguments: "{}" }], + webSearches: ["test"], + }, + }; + const journal = new Journal(); + const req = createMockReq(); + const res = createMockRes(); + + await handleCohere( + req, + res, + JSON.stringify({ model: "cmd-r", messages: [{ role: "user", content: "web-tool" }] }), + [fixture], + journal, + createDefaults({ logger }), + () => {}, + ); + + expect(warnings.some((w) => 
w.includes("webSearches") && w.includes("Cohere"))).toBe(true); + }); +}); + +// ─── Cohere response_format forwarding ───────────────────────────────────── + +describe("Cohere response_format forwarding", () => { + it("forwards response_format to ChatCompletionRequest", () => { + const result = cohereToCompletionRequest({ + model: "command-r-plus", + messages: [{ role: "user", content: "hello" }], + response_format: { type: "json_object" }, + } as Parameters[0]); + expect(result.response_format).toEqual({ type: "json_object" }); + }); + + it("omits response_format when not provided", () => { + const result = cohereToCompletionRequest({ + model: "command-r-plus", + messages: [{ role: "user", content: "hello" }], + } as Parameters[0]); + expect(result.response_format).toBeUndefined(); + }); +}); + +// ─── Cohere assistant tool_calls mapping ─────────────────────────────────── + +describe("Cohere assistant tool_calls mapping", () => { + it("maps assistant tool_calls to ChatCompletionRequest format", () => { + const result = cohereToCompletionRequest({ + model: "command-r-plus", + messages: [ + { role: "user", content: "hi" }, + { + role: "assistant", + content: "Using tool", + tool_calls: [ + { + id: "tc-1", + type: "function", + function: { name: "get_weather", arguments: '{"city":"SF"}' }, + }, + ], + }, + { role: "tool", content: "72F", tool_call_id: "tc-1" }, + { role: "user", content: "thanks" }, + ], + } as Parameters[0]); + + const assistantMsg = result.messages.find( + (m) => m.role === "assistant" && m.tool_calls && m.tool_calls.length > 0, + ); + expect(assistantMsg).toBeDefined(); + expect(assistantMsg!.tool_calls).toHaveLength(1); + expect(assistantMsg!.tool_calls![0].function.name).toBe("get_weather"); + expect(assistantMsg!.tool_calls![0].function.arguments).toBe('{"city":"SF"}'); + expect(assistantMsg!.tool_calls![0].id).toBe("tc-1"); + expect(assistantMsg!.content).toBe("Using tool"); + }); + + it("generates tool_call id when not provided", () => { + 
const result = cohereToCompletionRequest({ + model: "command-r-plus", + messages: [ + { role: "user", content: "hi" }, + { + role: "assistant", + content: "", + tool_calls: [ + { + type: "function", + function: { name: "fn", arguments: "{}" }, + }, + ], + }, + ], + } as Parameters[0]); + + const assistantMsg = result.messages.find( + (m) => m.role === "assistant" && m.tool_calls && m.tool_calls.length > 0, + ); + expect(assistantMsg!.tool_calls![0].id).toBeTruthy(); + }); + + it("falls back to plain assistant message when no tool_calls present", () => { + const result = cohereToCompletionRequest({ + model: "command-r-plus", + messages: [ + { role: "user", content: "hi" }, + { role: "assistant", content: "just text" }, + ], + } as Parameters[0]); + + const assistantMsg = result.messages.find((m) => m.role === "assistant"); + expect(assistantMsg!.content).toBe("just text"); + expect(assistantMsg!.tool_calls).toBeUndefined(); + }); +}); + +// ─── Cohere ResponseOverrides ────────────────────────────────────────────── + +describe("Cohere ResponseOverrides", () => { + it("applies id override on non-streaming text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "ov-id" }, + response: { content: "Hi!", id: "custom-id-123" }, + }, + ]; + instance = await createServer(fixtures); + const res = await post(`${instance.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "ov-id" }], + }); + expect(res.status).toBe(200); + const json = JSON.parse(res.body); + expect(json.id).toBe("custom-id-123"); + }); + + it("applies finishReason override on non-streaming text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "ov-fr" }, + response: { content: "Done.", finishReason: "length" }, + }, + ]; + instance = await createServer(fixtures); + const res = await post(`${instance.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "ov-fr" }], + }); + 
expect(res.status).toBe(200); + const json = JSON.parse(res.body); + expect(json.finish_reason).toBe("MAX_TOKENS"); + }); + + it("applies usage override on non-streaming text response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "ov-usage" }, + response: { + content: "Done.", + usage: { prompt_tokens: 10, completion_tokens: 20 }, + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post(`${instance.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "ov-usage" }], + }); + expect(res.status).toBe(200); + const json = JSON.parse(res.body); + expect(json.usage.tokens.input_tokens).toBe(10); + expect(json.usage.tokens.output_tokens).toBe(20); + expect(json.usage.billed_units.input_tokens).toBe(10); + expect(json.usage.billed_units.output_tokens).toBe(20); + }); + + it("applies overrides on non-streaming tool call response", async () => { + const fixtures: Fixture[] = [ + { + match: { userMessage: "ov-tc" }, + response: { + toolCalls: [{ name: "fn", arguments: '{"a":1}' }], + id: "tc-override-id", + finishReason: "stop", + }, + }, + ]; + instance = await createServer(fixtures); + const res = await post(`${instance.url}/v2/chat`, { + model: "command-r-plus", + messages: [{ role: "user", content: "ov-tc" }], + }); + expect(res.status).toBe(200); + const json = JSON.parse(res.body); + expect(json.id).toBe("tc-override-id"); + expect(json.finish_reason).toBe("COMPLETE"); + }); +}); diff --git a/src/__tests__/ollama.test.ts b/src/__tests__/ollama.test.ts index 13b9f7b..f92b634 100644 --- a/src/__tests__/ollama.test.ts +++ b/src/__tests__/ollama.test.ts @@ -938,7 +938,7 @@ describe("POST /api/chat (malformed tool call arguments)", () => { // ─── Integration tests: tool call on /api/generate → 500 ─────────────────── describe("POST /api/generate (tool call fixture)", () => { - it("returns 500 'unknown type' for tool call fixtures on /api/generate", async () => { + it("returns 400 for tool 
call fixtures on /api/generate with clear error", async () => { const tcFixture: Fixture = { match: { userMessage: "tool-gen" }, response: { @@ -952,9 +952,9 @@ describe("POST /api/generate (tool call fixture)", () => { stream: false, }); - expect(res.status).toBe(500); + expect(res.status).toBe(400); const body = JSON.parse(res.body); - expect(body.error.message).toContain("did not match any known type"); + expect(body.error.message).toContain("Tool call fixtures are not supported on /api/generate"); }); }); @@ -1100,7 +1100,7 @@ describe("POST /api/generate (malformed JSON)", () => { // ─── Integration tests: POST /api/generate (unknown response type streaming) ─ describe("POST /api/generate (unknown response type streaming)", () => { - it("returns 500 for tool call fixture on /api/generate (streaming default)", async () => { + it("returns 400 for tool call fixture on /api/generate (streaming default)", async () => { const tcFixture: Fixture = { match: { userMessage: "tool-gen-stream" }, response: { @@ -1114,9 +1114,9 @@ describe("POST /api/generate (unknown response type streaming)", () => { // stream omitted → defaults to true }); - expect(res.status).toBe(500); + expect(res.status).toBe(400); const body = JSON.parse(res.body); - expect(body.error.message).toContain("did not match any known type"); + expect(body.error.message).toContain("Tool call fixtures are not supported on /api/generate"); }); }); @@ -1262,12 +1262,12 @@ describe("ollamaToCompletionRequest (edge cases)", () => { expect(result.max_tokens).toBeUndefined(); }); - it("handles stream undefined (passes through as undefined)", () => { + it("defaults stream to true when absent (matches Ollama default)", () => { const result = ollamaToCompletionRequest({ model: "llama3", messages: [{ role: "user", content: "hi" }], }); - expect(result.stream).toBeUndefined(); + expect(result.stream).toBe(true); }); it("handles empty tools array (returns undefined)", () => { diff --git 
a/src/__tests__/provider-compat.test.ts b/src/__tests__/provider-compat.test.ts index a6d9897..93a820e 100644 --- a/src/__tests__/provider-compat.test.ts +++ b/src/__tests__/provider-compat.test.ts @@ -303,11 +303,13 @@ describe("OpenAI-compatible path prefix normalization", () => { stream: false, }); - // Normalization works: we get "No fixture matched" from the Responses handler - // (not "Not found" which would mean the path wasn't routed at all) + // Normalization works: the Responses handler receives the request, + // correctly parses the string input, matches the fixture, and returns + // a valid Responses API envelope. const parsed = JSON.parse(body); - expect(parsed.error.type).toBe("invalid_request_error"); - expect(parsed.error.code).toBe("no_fixture_match"); + expect(parsed.object).toBe("response"); + expect(parsed.output).toBeDefined(); + expect(parsed.output.length).toBeGreaterThan(0); }); it("normalizes /custom/audio/speech to /v1/audio/speech", async () => { diff --git a/src/__tests__/reasoning-all-providers.test.ts b/src/__tests__/reasoning-all-providers.test.ts index 42657c2..303c6f0 100644 --- a/src/__tests__/reasoning-all-providers.test.ts +++ b/src/__tests__/reasoning-all-providers.test.ts @@ -442,12 +442,14 @@ describe("POST /model/{id}/invoke-with-response-stream (reasoning streaming)", ( const thinkingStartIdx = frames.findIndex( (f) => f.eventType === "contentBlockStart" && - (f.payload as { start?: { type?: string } }).start?.type === "thinking", + (f.payload as { contentBlockStart?: { start?: { type?: string } } }).contentBlockStart + ?.start?.type === "thinking", ); const textStartIdx = frames.findIndex( (f) => f.eventType === "contentBlockStart" && - (f.payload as { start?: { type?: string } }).start?.type === undefined, + (f.payload as { contentBlockStart?: { start?: { type?: string } } }).contentBlockStart + ?.start?.type === "text", ); expect(thinkingStartIdx).toBeGreaterThan(0); @@ -457,10 +459,15 @@ describe("POST 
/model/{id}/invoke-with-response-stream (reasoning streaming)", ( const thinkingDeltas = frames.filter( (f) => f.eventType === "contentBlockDelta" && - (f.payload as { delta?: { type?: string } }).delta?.type === "thinking_delta", + (f.payload as { contentBlockDelta?: { delta?: { type?: string } } }).contentBlockDelta + ?.delta?.type === "thinking_delta", ); const fullThinking = thinkingDeltas - .map((f) => (f.payload as { delta: { thinking: string } }).delta.thinking) + .map( + (f) => + (f.payload as { contentBlockDelta: { delta: { thinking: string } } }).contentBlockDelta + .delta.thinking, + ) .join(""); expect(fullThinking).toBe("Let me think step by step about this problem."); @@ -468,10 +475,15 @@ describe("POST /model/{id}/invoke-with-response-stream (reasoning streaming)", ( const textDeltas = frames.filter( (f) => f.eventType === "contentBlockDelta" && - (f.payload as { delta?: { type?: string } }).delta?.type === "text_delta", + typeof (f.payload as { contentBlockDelta?: { delta?: { text?: string } } }) + .contentBlockDelta?.delta?.text === "string", ); const fullText = textDeltas - .map((f) => (f.payload as { delta: { text: string } }).delta.text) + .map( + (f) => + (f.payload as { contentBlockDelta: { delta: { text: string } } }).contentBlockDelta.delta + .text, + ) .join(""); expect(fullText).toBe("The answer is 42."); @@ -495,7 +507,8 @@ describe("POST /model/{id}/invoke-with-response-stream (reasoning streaming)", ( const thinkingDeltas = frames.filter( (f) => f.eventType === "contentBlockDelta" && - (f.payload as { delta?: { type?: string } }).delta?.type === "thinking_delta", + (f.payload as { contentBlockDelta?: { delta?: { type?: string } } }).contentBlockDelta + ?.delta?.type === "thinking_delta", ); expect(thinkingDeltas).toHaveLength(0); }); @@ -519,12 +532,14 @@ describe("POST /model/{id}/converse-stream (reasoning streaming)", () => { const thinkingStartIdx = frames.findIndex( (f) => f.eventType === "contentBlockStart" && - (f.payload as { 
start?: { type?: string } }).start?.type === "thinking", + (f.payload as { contentBlockStart?: { start?: { type?: string } } }).contentBlockStart + ?.start?.type === "thinking", ); const textStartIdx = frames.findIndex( (f) => f.eventType === "contentBlockStart" && - (f.payload as { start?: { type?: string } }).start?.type === undefined, + (f.payload as { contentBlockStart?: { start?: { type?: string } } }).contentBlockStart + ?.start?.type === "text", ); expect(thinkingStartIdx).toBeGreaterThan(0); @@ -534,10 +549,15 @@ describe("POST /model/{id}/converse-stream (reasoning streaming)", () => { const thinkingDeltas = frames.filter( (f) => f.eventType === "contentBlockDelta" && - (f.payload as { delta?: { type?: string } }).delta?.type === "thinking_delta", + (f.payload as { contentBlockDelta?: { delta?: { type?: string } } }).contentBlockDelta + ?.delta?.type === "thinking_delta", ); const fullThinking = thinkingDeltas - .map((f) => (f.payload as { delta: { thinking: string } }).delta.thinking) + .map( + (f) => + (f.payload as { contentBlockDelta: { delta: { thinking: string } } }).contentBlockDelta + .delta.thinking, + ) .join(""); expect(fullThinking).toBe("Let me think step by step about this problem."); @@ -555,7 +575,8 @@ describe("POST /model/{id}/converse-stream (reasoning streaming)", () => { const thinkingDeltas = frames.filter( (f) => f.eventType === "contentBlockDelta" && - (f.payload as { delta?: { type?: string } }).delta?.type === "thinking_delta", + (f.payload as { contentBlockDelta?: { delta?: { type?: string } } }).contentBlockDelta + ?.delta?.type === "thinking_delta", ); expect(thinkingDeltas).toHaveLength(0); }); @@ -732,13 +753,19 @@ describe("buildBedrockStreamTextEvents (reasoning)", () => { // Thinking block at index 0 expect(events[1]).toEqual({ eventType: "contentBlockStart", - payload: { contentBlockIndex: 0, start: { type: "thinking" } }, + payload: { + contentBlockIndex: 0, + contentBlockStart: { contentBlockIndex: 0, start: { type: 
"thinking" } }, + }, }); expect(events[2]).toEqual({ eventType: "contentBlockDelta", payload: { contentBlockIndex: 0, - delta: { type: "thinking_delta", thinking: "Step by step." }, + contentBlockDelta: { + contentBlockIndex: 0, + delta: { type: "thinking_delta", thinking: "Step by step." }, + }, }, }); expect(events[3]).toEqual({ @@ -749,13 +776,19 @@ describe("buildBedrockStreamTextEvents (reasoning)", () => { // Text block at index 1 expect(events[4]).toEqual({ eventType: "contentBlockStart", - payload: { contentBlockIndex: 1, start: {} }, + payload: { + contentBlockIndex: 1, + contentBlockStart: { contentBlockIndex: 1, start: { type: "text" } }, + }, }); expect(events[5]).toEqual({ eventType: "contentBlockDelta", payload: { contentBlockIndex: 1, - delta: { type: "text_delta", text: "The answer." }, + contentBlockDelta: { + contentBlockIndex: 1, + delta: { type: "text_delta", text: "The answer." }, + }, }, }); expect(events[6]).toEqual({ diff --git a/src/bedrock-converse.ts b/src/bedrock-converse.ts index 3d91357..605126d 100644 --- a/src/bedrock-converse.ts +++ b/src/bedrock-converse.ts @@ -13,13 +13,16 @@ import type { ChatMessage, Fixture, HandlerDefaults, + ResponseOverrides, ToolCall, ToolDefinition, } from "./types.js"; import { generateToolUseId, + extractOverrides, isTextResponse, isToolCallResponse, + isContentWithToolCallsResponse, isErrorResponse, flattenHeaders, getTestId, @@ -32,7 +35,11 @@ import type { Journal } from "./journal.js"; import type { Logger } from "./logger.js"; import { applyChaos } from "./chaos.js"; import { proxyAndRecord } from "./recorder.js"; -import { buildBedrockStreamTextEvents, buildBedrockStreamToolCallEvents } from "./bedrock.js"; +import { + buildBedrockStreamTextEvents, + buildBedrockStreamToolCallEvents, + buildBedrockStreamContentWithToolCallsEvents, +} from "./bedrock.js"; // ─── Converse request types ───────────────────────────────────────────────── @@ -60,6 +67,30 @@ interface ConverseRequest { toolConfig?: { 
tools: { toolSpec: ConverseToolSpec }[] }; } +// ─── Converse stop_reason mapping ────────────────────────────────────────── + +function converseStopReason( + overrideFinishReason: string | undefined, + defaultReason: string, +): string { + if (!overrideFinishReason) return defaultReason; + if (overrideFinishReason === "stop") return "end_turn"; + if (overrideFinishReason === "tool_calls") return "tool_use"; + if (overrideFinishReason === "length") return "max_tokens"; + return overrideFinishReason; +} + +function converseUsage(overrides?: ResponseOverrides): { + inputTokens: number; + outputTokens: number; + totalTokens: number; +} { + if (!overrides?.usage) return { inputTokens: 0, outputTokens: 0, totalTokens: 0 }; + const inputTokens = overrides.usage.input_tokens ?? overrides.usage.prompt_tokens ?? 0; + const outputTokens = overrides.usage.output_tokens ?? overrides.usage.completion_tokens ?? 0; + return { inputTokens, outputTokens, totalTokens: inputTokens + outputTokens }; +} + // ─── Input conversion: Converse → ChatCompletionRequest ───────────────────── export function converseToCompletionRequest( @@ -157,7 +188,11 @@ export function converseToCompletionRequest( // ─── Response builders ────────────────────────────────────────────────────── -function buildConverseTextResponse(content: string, reasoning?: string): object { +function buildConverseTextResponse( + content: string, + reasoning?: string, + overrides?: ResponseOverrides, +): object { const contentBlocks: object[] = []; if (reasoning) { contentBlocks.push({ @@ -173,12 +208,16 @@ function buildConverseTextResponse(content: string, reasoning?: string): object content: contentBlocks, }, }, - stopReason: "end_turn", - usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + stopReason: converseStopReason(overrides?.finishReason, "end_turn"), + usage: converseUsage(overrides), }; } -function buildConverseToolCallResponse(toolCalls: ToolCall[], logger: Logger): object { +function 
buildConverseToolCallResponse( + toolCalls: ToolCall[], + logger: Logger, + overrides?: ResponseOverrides, +): object { return { output: { message: { @@ -203,8 +242,53 @@ function buildConverseToolCallResponse(toolCalls: ToolCall[], logger: Logger): o }), }, }, - stopReason: "tool_use", - usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + stopReason: converseStopReason(overrides?.finishReason, "tool_use"), + usage: converseUsage(overrides), + }; +} + +function buildConverseContentWithToolCallsResponse( + content: string, + toolCalls: ToolCall[], + logger: Logger, + reasoning?: string, + overrides?: ResponseOverrides, +): object { + const contentBlocks: object[] = []; + if (reasoning) { + contentBlocks.push({ + reasoningContent: { reasoningText: { text: reasoning } }, + }); + } + contentBlocks.push({ text: content }); + for (const tc of toolCalls) { + let argsObj: unknown; + try { + argsObj = JSON.parse(tc.arguments || "{}"); + } catch { + logger.warn( + `Malformed JSON in fixture tool call arguments for "${tc.name}": ${tc.arguments}`, + ); + argsObj = {}; + } + contentBlocks.push({ + toolUse: { + toolUseId: tc.id || generateToolUseId(), + name: tc.name, + input: argsObj, + }, + }); + } + + return { + output: { + message: { + role: "assistant", + content: contentBlocks, + }, + }, + stopReason: converseStopReason(overrides?.finishReason, "tool_use"), + usage: converseUsage(overrides), }; } @@ -298,7 +382,6 @@ export async function handleConverse( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -307,7 +390,7 @@ export async function handleConverse( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -317,13 +400,13 @@ export async function handleConverse( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? 
"POST", path: urlPath, headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } @@ -367,12 +450,52 @@ export async function handleConverse( body: completionReq, response: { status, fixture }, }); - writeErrorResponse(res, status, JSON.stringify(response)); + const errBody = { + type: "error", + error: { + type: response.error.type || "invalid_request_error", + message: response.error.message, + }, + }; + writeErrorResponse(res, status, JSON.stringify(errBody)); + return; + } + + // Content + tool calls response + if (isContentWithToolCallsResponse(response)) { + if (response.webSearches?.length) { + logger.warn( + "webSearches in fixture response are not supported for Bedrock Converse API — ignoring", + ); + } + const overrides = extractOverrides(response); + journal.add({ + method: req.method ?? "POST", + path: urlPath, + headers: flattenHeaders(req.headers), + body: completionReq, + response: { status: 200, fixture }, + }); + const body = buildConverseContentWithToolCallsResponse( + response.content, + response.toolCalls, + logger, + response.reasoning, + overrides, + ); + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify(body)); return; } // Text response if (isTextResponse(response)) { + if (response.webSearches?.length) { + logger.warn( + "webSearches in fixture response are not supported for Bedrock Converse API — ignoring", + ); + } + const overrides = extractOverrides(response); journal.add({ method: req.method ?? 
"POST", path: urlPath, @@ -380,7 +503,7 @@ export async function handleConverse( body: completionReq, response: { status: 200, fixture }, }); - const body = buildConverseTextResponse(response.content, response.reasoning); + const body = buildConverseTextResponse(response.content, response.reasoning, overrides); res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify(body)); return; @@ -388,6 +511,7 @@ export async function handleConverse( // Tool call response if (isToolCallResponse(response)) { + const overrides = extractOverrides(response); journal.add({ method: req.method ?? "POST", path: urlPath, @@ -395,7 +519,7 @@ export async function handleConverse( body: completionReq, response: { status: 200, fixture }, }); - const body = buildConverseToolCallResponse(response.toolCalls, logger); + const body = buildConverseToolCallResponse(response.toolCalls, logger, overrides); res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify(body)); return; @@ -509,7 +633,6 @@ export async function handleConverseStream( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -518,7 +641,7 @@ export async function handleConverseStream( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -528,13 +651,13 @@ export async function handleConverseStream( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: urlPath, headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 
200, fixture: null, source: "proxy" }, }); return; } @@ -580,12 +703,64 @@ export async function handleConverseStream( body: completionReq, response: { status, fixture }, }); - writeErrorResponse(res, status, JSON.stringify(response)); + const errBody = { + type: "error", + error: { + type: response.error.type || "invalid_request_error", + message: response.error.message, + }, + }; + writeErrorResponse(res, status, JSON.stringify(errBody)); + return; + } + + // Content + tool calls response — stream as Event Stream + if (isContentWithToolCallsResponse(response)) { + if (response.webSearches?.length) { + logger.warn( + "webSearches in fixture response are not supported for Bedrock Converse API — ignoring", + ); + } + const overrides = extractOverrides(response); + const journalEntry = journal.add({ + method: req.method ?? "POST", + path: urlPath, + headers: flattenHeaders(req.headers), + body: completionReq, + response: { status: 200, fixture }, + }); + const events = buildBedrockStreamContentWithToolCallsEvents( + response.content, + response.toolCalls, + chunkSize, + logger, + response.reasoning, + overrides, + ); + const interruption = createInterruptionSignal(fixture); + const completed = await writeEventStream(res, events, { + latency, + streamingProfile: fixture.streamingProfile, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); return; } // Text response — stream as Event Stream if (isTextResponse(response)) { + if (response.webSearches?.length) { + logger.warn( + "webSearches in fixture response are not supported for Bedrock Converse API — ignoring", + ); + } + const overrides = extractOverrides(response); const journalEntry = journal.add({ method: req.method ?? 
"POST", path: urlPath, @@ -593,7 +768,12 @@ export async function handleConverseStream( body: completionReq, response: { status: 200, fixture }, }); - const events = buildBedrockStreamTextEvents(response.content, chunkSize, response.reasoning); + const events = buildBedrockStreamTextEvents( + response.content, + chunkSize, + response.reasoning, + overrides, + ); const interruption = createInterruptionSignal(fixture); const completed = await writeEventStream(res, events, { latency, @@ -612,6 +792,7 @@ export async function handleConverseStream( // Tool call response — stream as Event Stream if (isToolCallResponse(response)) { + const overrides = extractOverrides(response); const journalEntry = journal.add({ method: req.method ?? "POST", path: urlPath, @@ -619,7 +800,12 @@ export async function handleConverseStream( body: completionReq, response: { status: 200, fixture }, }); - const events = buildBedrockStreamToolCallEvents(response.toolCalls, chunkSize, logger); + const events = buildBedrockStreamToolCallEvents( + response.toolCalls, + chunkSize, + logger, + overrides, + ); const interruption = createInterruptionSignal(fixture); const completed = await writeEventStream(res, events, { latency, diff --git a/src/bedrock.ts b/src/bedrock.ts index 649d20d..02a3dc1 100644 --- a/src/bedrock.ts +++ b/src/bedrock.ts @@ -23,14 +23,17 @@ import type { ChatMessage, Fixture, HandlerDefaults, + ResponseOverrides, ToolCall, ToolDefinition, } from "./types.js"; import { generateMessageId, generateToolUseId, + extractOverrides, isTextResponse, isToolCallResponse, + isContentWithToolCallsResponse, isErrorResponse, flattenHeaders, getTestId, @@ -79,6 +82,30 @@ interface BedrockRequest { [key: string]: unknown; } +// ─── Bedrock stop_reason mapping ─────────────────────────────────────────── + +function bedrockStopReason( + overrideFinishReason: string | undefined, + defaultReason: string, +): string { + if (!overrideFinishReason) return defaultReason; + if (overrideFinishReason === 
"stop") return "end_turn"; + if (overrideFinishReason === "tool_calls") return "tool_use"; + if (overrideFinishReason === "length") return "max_tokens"; + return overrideFinishReason; +} + +function bedrockUsage(overrides?: ResponseOverrides): { + input_tokens: number; + output_tokens: number; +} { + if (!overrides?.usage) return { input_tokens: 0, output_tokens: 0 }; + return { + input_tokens: overrides.usage.input_tokens ?? overrides.usage.prompt_tokens ?? 0, + output_tokens: overrides.usage.output_tokens ?? overrides.usage.completion_tokens ?? 0, + }; +} + // ─── Input conversion: Bedrock → ChatCompletionRequest ────────────────────── function extractTextContent(content: string | BedrockContentBlock[]): string { @@ -199,7 +226,12 @@ export function bedrockToCompletionRequest( // ─── Response builders ────────────────────────────────────────────────────── -function buildBedrockTextResponse(content: string, model: string, reasoning?: string): object { +function buildBedrockTextResponse( + content: string, + model: string, + reasoning?: string, + overrides?: ResponseOverrides, +): object { const contentBlocks: object[] = []; if (reasoning) { contentBlocks.push({ type: "thinking", thinking: reasoning }); @@ -207,14 +239,14 @@ function buildBedrockTextResponse(content: string, model: string, reasoning?: st contentBlocks.push({ type: "text", text: content }); return { - id: generateMessageId(), + id: overrides?.id ?? generateMessageId(), type: "message", role: "assistant", content: contentBlocks, - model, - stop_reason: "end_turn", + model: overrides?.model ?? model, + stop_reason: bedrockStopReason(overrides?.finishReason, "end_turn"), stop_sequence: null, - usage: { input_tokens: 0, output_tokens: 0 }, + usage: bedrockUsage(overrides), }; } @@ -222,9 +254,10 @@ function buildBedrockToolCallResponse( toolCalls: ToolCall[], model: string, logger: Logger, + overrides?: ResponseOverrides, ): object { return { - id: generateMessageId(), + id: overrides?.id ?? 
generateMessageId(), type: "message", role: "assistant", content: toolCalls.map((tc) => { @@ -244,10 +277,10 @@ function buildBedrockToolCallResponse( input: argsObj, }; }), - model, - stop_reason: "tool_use", + model: overrides?.model ?? model, + stop_reason: bedrockStopReason(overrides?.finishReason, "tool_use"), stop_sequence: null, - usage: { input_tokens: 0, output_tokens: 0 }, + usage: bedrockUsage(overrides), }; } @@ -342,7 +375,6 @@ export async function handleBedrock( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -351,7 +383,7 @@ export async function handleBedrock( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -361,13 +393,13 @@ export async function handleBedrock( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: urlPath, headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } @@ -423,8 +455,51 @@ export async function handleBedrock( return; } + // Content + tool calls response + if (isContentWithToolCallsResponse(response)) { + if (response.webSearches?.length) { + logger.warn("webSearches in fixture response are not supported for Bedrock API — ignoring"); + } + const overrides = extractOverrides(response); + journal.add({ + method: req.method ?? 
"POST", + path: urlPath, + headers: flattenHeaders(req.headers), + body: completionReq, + response: { status: 200, fixture }, + }); + const textBody = buildBedrockTextResponse( + response.content, + completionReq.model, + response.reasoning, + overrides, + ); + const toolBody = buildBedrockToolCallResponse( + response.toolCalls, + completionReq.model, + logger, + overrides, + ); + // Merge: take the text response as base, append tool_use blocks, set stop_reason to tool_use + const merged = { + ...(textBody as Record), + content: [ + ...((textBody as Record).content as object[]), + ...((toolBody as Record).content as object[]), + ], + stop_reason: bedrockStopReason(overrides?.finishReason, "tool_use"), + }; + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify(merged)); + return; + } + // Text response if (isTextResponse(response)) { + if (response.webSearches?.length) { + logger.warn("webSearches in fixture response are not supported for Bedrock API — ignoring"); + } + const overrides = extractOverrides(response); journal.add({ method: req.method ?? "POST", path: urlPath, @@ -436,6 +511,7 @@ export async function handleBedrock( response.content, completionReq.model, response.reasoning, + overrides, ); res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify(body)); @@ -444,6 +520,7 @@ export async function handleBedrock( // Tool call response if (isToolCallResponse(response)) { + const overrides = extractOverrides(response); journal.add({ method: req.method ?? 
"POST", path: urlPath, @@ -451,7 +528,12 @@ export async function handleBedrock( body: completionReq, response: { status: 200, fixture }, }); - const body = buildBedrockToolCallResponse(response.toolCalls, completionReq.model, logger); + const body = buildBedrockToolCallResponse( + response.toolCalls, + completionReq.model, + logger, + overrides, + ); res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify(body)); return; @@ -483,12 +565,13 @@ export function buildBedrockStreamTextEvents( content: string, chunkSize: number, reasoning?: string, + overrides?: ResponseOverrides, ): Array<{ eventType: string; payload: object }> { const events: Array<{ eventType: string; payload: object }> = []; events.push({ eventType: "messageStart", - payload: { role: "assistant" }, + payload: { messageStart: { role: "assistant" } }, }); // Thinking block (emitted before text when reasoning is present) @@ -496,7 +579,13 @@ export function buildBedrockStreamTextEvents( const blockIndex = 0; events.push({ eventType: "contentBlockStart", - payload: { contentBlockIndex: blockIndex, start: { type: "thinking" } }, + payload: { + contentBlockIndex: blockIndex, + contentBlockStart: { + contentBlockIndex: blockIndex, + start: { type: "thinking" }, + }, + }, }); for (let i = 0; i < reasoning.length; i += chunkSize) { @@ -505,7 +594,10 @@ export function buildBedrockStreamTextEvents( eventType: "contentBlockDelta", payload: { contentBlockIndex: blockIndex, - delta: { type: "thinking_delta", thinking: slice }, + contentBlockDelta: { + contentBlockIndex: blockIndex, + delta: { type: "thinking_delta", thinking: slice }, + }, }, }); } @@ -521,7 +613,13 @@ export function buildBedrockStreamTextEvents( events.push({ eventType: "contentBlockStart", - payload: { contentBlockIndex: textBlockIndex, start: {} }, + payload: { + contentBlockIndex: textBlockIndex, + contentBlockStart: { + contentBlockIndex: textBlockIndex, + start: { type: "text" }, + }, + }, }); for (let i = 0; i 
< content.length; i += chunkSize) { @@ -530,7 +628,10 @@ export function buildBedrockStreamTextEvents( eventType: "contentBlockDelta", payload: { contentBlockIndex: textBlockIndex, - delta: { type: "text_delta", text: slice }, + contentBlockDelta: { + contentBlockIndex: textBlockIndex, + delta: { type: "text_delta", text: slice }, + }, }, }); } @@ -542,7 +643,142 @@ export function buildBedrockStreamTextEvents( events.push({ eventType: "messageStop", - payload: { stopReason: "end_turn" }, + payload: { stopReason: bedrockStopReason(overrides?.finishReason, "end_turn") }, + }); + + return events; +} + +export function buildBedrockStreamContentWithToolCallsEvents( + content: string, + toolCalls: ToolCall[], + chunkSize: number, + logger: Logger, + reasoning?: string, + overrides?: ResponseOverrides, +): Array<{ eventType: string; payload: object }> { + const events: Array<{ eventType: string; payload: object }> = []; + + events.push({ + eventType: "messageStart", + payload: { messageStart: { role: "assistant" } }, + }); + + let blockIndex = 0; + + // Thinking block (emitted before text when reasoning is present) + if (reasoning) { + events.push({ + eventType: "contentBlockStart", + payload: { + contentBlockIndex: blockIndex, + contentBlockStart: { + contentBlockIndex: blockIndex, + start: { type: "thinking" }, + }, + }, + }); + for (let i = 0; i < reasoning.length; i += chunkSize) { + const slice = reasoning.slice(i, i + chunkSize); + events.push({ + eventType: "contentBlockDelta", + payload: { + contentBlockIndex: blockIndex, + contentBlockDelta: { + contentBlockIndex: blockIndex, + delta: { type: "thinking_delta", thinking: slice }, + }, + }, + }); + } + events.push({ + eventType: "contentBlockStop", + payload: { contentBlockIndex: blockIndex }, + }); + blockIndex++; + } + + // Text block + events.push({ + eventType: "contentBlockStart", + payload: { + contentBlockIndex: blockIndex, + contentBlockStart: { + contentBlockIndex: blockIndex, + start: { type: "text" }, + 
}, + }, + }); + for (let i = 0; i < content.length; i += chunkSize) { + const slice = content.slice(i, i + chunkSize); + events.push({ + eventType: "contentBlockDelta", + payload: { + contentBlockIndex: blockIndex, + contentBlockDelta: { + contentBlockIndex: blockIndex, + delta: { type: "text_delta", text: slice }, + }, + }, + }); + } + events.push({ + eventType: "contentBlockStop", + payload: { contentBlockIndex: blockIndex }, + }); + blockIndex++; + + // Tool call blocks + for (let tcIdx = 0; tcIdx < toolCalls.length; tcIdx++) { + const tc = toolCalls[tcIdx]; + const toolUseId = tc.id || generateToolUseId(); + const currentBlock = blockIndex + tcIdx; + + events.push({ + eventType: "contentBlockStart", + payload: { + contentBlockIndex: currentBlock, + contentBlockStart: { + contentBlockIndex: currentBlock, + start: { toolUse: { toolUseId, name: tc.name } }, + }, + }, + }); + + let argsStr: string; + try { + const parsed = JSON.parse(tc.arguments || "{}"); + argsStr = JSON.stringify(parsed); + } catch { + logger.warn( + `Malformed JSON in fixture tool call arguments for "${tc.name}": ${tc.arguments}`, + ); + argsStr = "{}"; + } + + for (let i = 0; i < argsStr.length; i += chunkSize) { + const slice = argsStr.slice(i, i + chunkSize); + events.push({ + eventType: "contentBlockDelta", + payload: { + contentBlockIndex: currentBlock, + contentBlockDelta: { + contentBlockIndex: currentBlock, + delta: { toolUse: { input: slice } }, + }, + }, + }); + } + + events.push({ + eventType: "contentBlockStop", + payload: { contentBlockIndex: currentBlock }, + }); + } + + events.push({ + eventType: "messageStop", + payload: { stopReason: bedrockStopReason(overrides?.finishReason, "tool_use") }, }); return events; @@ -552,12 +788,13 @@ export function buildBedrockStreamToolCallEvents( toolCalls: ToolCall[], chunkSize: number, logger: Logger, + overrides?: ResponseOverrides, ): Array<{ eventType: string; payload: object }> { const events: Array<{ eventType: string; payload: object }> 
= []; events.push({ eventType: "messageStart", - payload: { role: "assistant" }, + payload: { messageStart: { role: "assistant" } }, }); for (let tcIdx = 0; tcIdx < toolCalls.length; tcIdx++) { @@ -568,8 +805,11 @@ export function buildBedrockStreamToolCallEvents( eventType: "contentBlockStart", payload: { contentBlockIndex: tcIdx, - start: { - toolUse: { toolUseId, name: tc.name }, + contentBlockStart: { + contentBlockIndex: tcIdx, + start: { + toolUse: { toolUseId, name: tc.name }, + }, }, }, }); @@ -591,7 +831,10 @@ export function buildBedrockStreamToolCallEvents( eventType: "contentBlockDelta", payload: { contentBlockIndex: tcIdx, - delta: { type: "input_json_delta", inputJSON: slice }, + contentBlockDelta: { + contentBlockIndex: tcIdx, + delta: { toolUse: { input: slice } }, + }, }, }); } @@ -604,7 +847,7 @@ export function buildBedrockStreamToolCallEvents( events.push({ eventType: "messageStop", - payload: { stopReason: "tool_use" }, + payload: { stopReason: bedrockStopReason(overrides?.finishReason, "tool_use") }, }); return events; @@ -700,7 +943,6 @@ export async function handleBedrockStream( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -709,7 +951,7 @@ export async function handleBedrockStream( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -719,13 +961,13 @@ export async function handleBedrockStream( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: urlPath, headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 
200, fixture: null, source: "proxy" }, }); return; } @@ -771,12 +1013,61 @@ export async function handleBedrockStream( body: completionReq, response: { status, fixture }, }); - writeErrorResponse(res, status, JSON.stringify(response)); + // Anthropic-style error format (Bedrock uses Claude): { type: "error", error: { type, message } } + const anthropicError = { + type: "error", + error: { + type: response.error.type ?? "api_error", + message: response.error.message, + }, + }; + writeErrorResponse(res, status, JSON.stringify(anthropicError)); + return; + } + + // Content + tool calls response — stream as Event Stream + if (isContentWithToolCallsResponse(response)) { + if (response.webSearches?.length) { + logger.warn("webSearches in fixture response are not supported for Bedrock API — ignoring"); + } + const overrides = extractOverrides(response); + const journalEntry = journal.add({ + method: req.method ?? "POST", + path: urlPath, + headers: flattenHeaders(req.headers), + body: completionReq, + response: { status: 200, fixture }, + }); + const events = buildBedrockStreamContentWithToolCallsEvents( + response.content, + response.toolCalls, + chunkSize, + logger, + response.reasoning, + overrides, + ); + const interruption = createInterruptionSignal(fixture); + const completed = await writeEventStream(res, events, { + latency, + streamingProfile: fixture.streamingProfile, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); return; } // Text response — stream as Event Stream if (isTextResponse(response)) { + if (response.webSearches?.length) { + logger.warn("webSearches in fixture response are not supported for Bedrock API — ignoring"); + } + const overrides = extractOverrides(response); const journalEntry = journal.add({ method: req.method ?? 
"POST", path: urlPath, @@ -784,7 +1075,12 @@ export async function handleBedrockStream( body: completionReq, response: { status: 200, fixture }, }); - const events = buildBedrockStreamTextEvents(response.content, chunkSize, response.reasoning); + const events = buildBedrockStreamTextEvents( + response.content, + chunkSize, + response.reasoning, + overrides, + ); const interruption = createInterruptionSignal(fixture); const completed = await writeEventStream(res, events, { latency, @@ -803,6 +1099,7 @@ export async function handleBedrockStream( // Tool call response — stream as Event Stream if (isToolCallResponse(response)) { + const overrides = extractOverrides(response); const journalEntry = journal.add({ method: req.method ?? "POST", path: urlPath, @@ -810,7 +1107,12 @@ export async function handleBedrockStream( body: completionReq, response: { status: 200, fixture }, }); - const events = buildBedrockStreamToolCallEvents(response.toolCalls, chunkSize, logger); + const events = buildBedrockStreamToolCallEvents( + response.toolCalls, + chunkSize, + logger, + overrides, + ); const interruption = createInterruptionSignal(fixture); const completed = await writeEventStream(res, events, { latency, diff --git a/src/cohere.ts b/src/cohere.ts index bd66f11..afc9901 100644 --- a/src/cohere.ts +++ b/src/cohere.ts @@ -15,6 +15,7 @@ import type { ChatMessage, Fixture, HandlerDefaults, + ResponseOverrides, StreamingProfile, ToolCall, ToolDefinition, @@ -22,8 +23,10 @@ import type { import { generateMessageId, generateToolCallId, + extractOverrides, isTextResponse, isToolCallResponse, + isContentWithToolCallsResponse, isErrorResponse, flattenHeaders, getTestId, @@ -38,10 +41,20 @@ import { proxyAndRecord } from "./recorder.js"; // ─── Cohere v2 Chat request types ─────────────────────────────────────────── +interface CohereToolCallDef { + id?: string; + type: string; + function: { + name: string; + arguments: string; + }; +} + interface CohereMessage { role: "user" | "assistant" 
| "system" | "tool"; content: string; tool_call_id?: string; + tool_calls?: CohereToolCallDef[]; } interface CohereToolDef { @@ -75,6 +88,34 @@ const ZERO_USAGE = { tokens: { input_tokens: 0, output_tokens: 0 }, }; +// ─── Cohere finish reason / usage mapping ────────────────────────────────── + +function cohereFinishReason( + overrideFinishReason: string | undefined, + defaultReason: string, +): string { + if (!overrideFinishReason) return defaultReason; + if (overrideFinishReason === "stop") return "COMPLETE"; + if (overrideFinishReason === "tool_calls") return "TOOL_CALL"; + if (overrideFinishReason === "length") return "MAX_TOKENS"; + return overrideFinishReason; +} + +function cohereUsage(overrides?: ResponseOverrides): typeof ZERO_USAGE { + if (!overrides?.usage) return ZERO_USAGE; + const inputTokens = overrides.usage.input_tokens ?? overrides.usage.prompt_tokens ?? 0; + const outputTokens = overrides.usage.output_tokens ?? overrides.usage.completion_tokens ?? 0; + return { + billed_units: { + input_tokens: inputTokens, + output_tokens: outputTokens, + search_units: 0, + classifications: 0, + }, + tokens: { input_tokens: inputTokens, output_tokens: outputTokens }, + }; +} + // ─── Input conversion: Cohere → ChatCompletionRequest ─────────────────────── export function cohereToCompletionRequest(req: CohereRequest): ChatCompletionRequest { @@ -86,7 +127,22 @@ export function cohereToCompletionRequest(req: CohereRequest): ChatCompletionReq } else if (msg.role === "user") { messages.push({ role: "user", content: msg.content }); } else if (msg.role === "assistant") { - messages.push({ role: "assistant", content: msg.content }); + if (msg.tool_calls && msg.tool_calls.length > 0) { + messages.push({ + role: "assistant", + content: msg.content || null, + tool_calls: msg.tool_calls.map((tc) => ({ + id: tc.id ?? 
generateToolCallId(), + type: "function" as const, + function: { + name: tc.function.name, + arguments: tc.function.arguments, + }, + })), + }); + } else { + messages.push({ role: "assistant", content: msg.content }); + } } else if (msg.role === "tool") { messages.push({ role: "tool", @@ -114,51 +170,69 @@ export function cohereToCompletionRequest(req: CohereRequest): ChatCompletionReq messages, stream: req.stream, tools, + ...(req.response_format && { response_format: req.response_format }), }; } // ─── Response building: fixture → Cohere v2 Chat format ───────────────────── // Non-streaming text response -function buildCohereTextResponse(content: string): object { +function buildCohereTextResponse( + content: string, + reasoning?: string, + overrides?: ResponseOverrides, +): object { + const contentBlocks: { type: string; text: string }[] = []; + if (reasoning) { + contentBlocks.push({ type: "text", text: reasoning }); + } + contentBlocks.push({ type: "text", text: content }); + return { - id: generateMessageId(), - finish_reason: "COMPLETE", + id: overrides?.id ?? 
generateMessageId(), + finish_reason: cohereFinishReason(overrides?.finishReason, "COMPLETE"), message: { role: "assistant", - content: [{ type: "text", text: content }], + content: contentBlocks, tool_calls: [], tool_plan: "", citations: [], }, - usage: ZERO_USAGE, + usage: cohereUsage(overrides), }; } // Non-streaming tool call response -function buildCohereToolCallResponse(toolCalls: ToolCall[], logger: Logger): object { +function buildCohereToolCallResponse( + toolCalls: ToolCall[], + logger: Logger, + overrides?: ResponseOverrides, +): object { const cohereCalls = toolCalls.map((tc) => { // Validate arguments JSON + let argsJson: string; try { JSON.parse(tc.arguments || "{}"); + argsJson = tc.arguments || "{}"; } catch { logger.warn( `Malformed JSON in fixture tool call arguments for "${tc.name}": ${tc.arguments}`, ); + argsJson = "{}"; } return { id: tc.id || generateToolCallId(), type: "function", function: { name: tc.name, - arguments: tc.arguments || "{}", + arguments: argsJson, }, }; }); return { - id: generateMessageId(), - finish_reason: "TOOL_CALL", + id: overrides?.id ?? 
generateMessageId(), + finish_reason: cohereFinishReason(overrides?.finishReason, "TOOL_CALL"), message: { role: "assistant", content: [], @@ -166,14 +240,68 @@ function buildCohereToolCallResponse(toolCalls: ToolCall[], logger: Logger): obj tool_plan: "", citations: [], }, - usage: ZERO_USAGE, + usage: cohereUsage(overrides), + }; +} + +// Non-streaming content + tool calls response +function buildCohereContentWithToolCallsResponse( + content: string, + toolCalls: ToolCall[], + logger: Logger, + reasoning?: string, + overrides?: ResponseOverrides, +): object { + const cohereCalls = toolCalls.map((tc) => { + let argsJson: string; + try { + JSON.parse(tc.arguments || "{}"); + argsJson = tc.arguments || "{}"; + } catch { + logger.warn( + `Malformed JSON in fixture tool call arguments for "${tc.name}": ${tc.arguments}`, + ); + argsJson = "{}"; + } + return { + id: tc.id || generateToolCallId(), + type: "function", + function: { + name: tc.name, + arguments: argsJson, + }, + }; + }); + + const contentBlocks: { type: string; text: string }[] = []; + if (reasoning) { + contentBlocks.push({ type: "text", text: reasoning }); + } + contentBlocks.push({ type: "text", text: content }); + + return { + id: overrides?.id ?? generateMessageId(), + finish_reason: cohereFinishReason(overrides?.finishReason, "TOOL_CALL"), + message: { + role: "assistant", + content: contentBlocks, + tool_calls: cohereCalls, + tool_plan: "", + citations: [], + }, + usage: cohereUsage(overrides), }; } // ─── Streaming event builders ─────────────────────────────────────────────── -function buildCohereTextStreamEvents(content: string, chunkSize: number): CohereSSEEvent[] { - const msgId = generateMessageId(); +function buildCohereTextStreamEvents( + content: string, + chunkSize: number, + reasoning?: string, + overrides?: ResponseOverrides, +): CohereSSEEvent[] { + const msgId = overrides?.id ?? 
generateMessageId(); const events: CohereSSEEvent[] = []; // message-start @@ -191,10 +319,31 @@ function buildCohereTextStreamEvents(content: string, chunkSize: number): Cohere }, }); + let contentIndex = 0; + + // Reasoning as a text block before main content (Cohere has no native reasoning type) + if (reasoning) { + events.push({ + type: "content-start", + index: contentIndex, + delta: { message: { content: { type: "text" } } }, + }); + for (let i = 0; i < reasoning.length; i += chunkSize) { + const slice = reasoning.slice(i, i + chunkSize); + events.push({ + type: "content-delta", + index: contentIndex, + delta: { message: { content: { type: "text", text: slice } } }, + }); + } + events.push({ type: "content-end", index: contentIndex }); + contentIndex++; + } + // content-start (type: "text" only, no text field) events.push({ type: "content-start", - index: 0, + index: contentIndex, delta: { message: { content: { type: "text" }, @@ -207,7 +356,7 @@ function buildCohereTextStreamEvents(content: string, chunkSize: number): Cohere const slice = content.slice(i, i + chunkSize); events.push({ type: "content-delta", - index: 0, + index: contentIndex, delta: { message: { content: { type: "text", text: slice }, @@ -219,15 +368,15 @@ function buildCohereTextStreamEvents(content: string, chunkSize: number): Cohere // content-end events.push({ type: "content-end", - index: 0, + index: contentIndex, }); // message-end events.push({ type: "message-end", delta: { - finish_reason: "COMPLETE", - usage: ZERO_USAGE, + finish_reason: cohereFinishReason(overrides?.finishReason, "COMPLETE"), + usage: cohereUsage(overrides), }, }); @@ -238,8 +387,9 @@ function buildCohereToolCallStreamEvents( toolCalls: ToolCall[], chunkSize: number, logger: Logger, + overrides?: ResponseOverrides, ): CohereSSEEvent[] { - const msgId = generateMessageId(); + const msgId = overrides?.id ?? 
generateMessageId(); const events: CohereSSEEvent[] = []; // message-start @@ -330,8 +480,167 @@ function buildCohereToolCallStreamEvents( events.push({ type: "message-end", delta: { - finish_reason: "TOOL_CALL", - usage: ZERO_USAGE, + finish_reason: cohereFinishReason(overrides?.finishReason, "TOOL_CALL"), + usage: cohereUsage(overrides), + }, + }); + + return events; +} + +function buildCohereContentWithToolCallsStreamEvents( + content: string, + toolCalls: ToolCall[], + chunkSize: number, + logger: Logger, + reasoning?: string, + overrides?: ResponseOverrides, +): CohereSSEEvent[] { + const msgId = overrides?.id ?? generateMessageId(); + const events: CohereSSEEvent[] = []; + + // message-start + events.push({ + id: msgId, + type: "message-start", + delta: { + message: { + role: "assistant", + content: [], + tool_plan: "", + tool_calls: [], + citations: [], + }, + }, + }); + + let contentIndex = 0; + + // Reasoning as a text block before main content + if (reasoning) { + events.push({ + type: "content-start", + index: contentIndex, + delta: { message: { content: { type: "text" } } }, + }); + for (let i = 0; i < reasoning.length; i += chunkSize) { + const slice = reasoning.slice(i, i + chunkSize); + events.push({ + type: "content-delta", + index: contentIndex, + delta: { message: { content: { type: "text", text: slice } } }, + }); + } + events.push({ type: "content-end", index: contentIndex }); + contentIndex++; + } + + // content-start (type: "text" only, no text field) + events.push({ + type: "content-start", + index: contentIndex, + delta: { + message: { + content: { type: "text" }, + }, + }, + }); + + // content-delta — text chunks + for (let i = 0; i < content.length; i += chunkSize) { + const slice = content.slice(i, i + chunkSize); + events.push({ + type: "content-delta", + index: contentIndex, + delta: { + message: { + content: { type: "text", text: slice }, + }, + }, + }); + } + + // content-end + events.push({ + type: "content-end", + index: 
contentIndex, + }); + + // tool-plan-delta + events.push({ + type: "tool-plan-delta", + delta: { + message: { + tool_plan: "I will use the requested tool.", + }, + }, + }); + + // Tool call events + for (let idx = 0; idx < toolCalls.length; idx++) { + const tc = toolCalls[idx]; + const callId = tc.id || generateToolCallId(); + + let argsJson: string; + try { + JSON.parse(tc.arguments || "{}"); + argsJson = tc.arguments || "{}"; + } catch { + logger.warn( + `Malformed JSON in fixture tool call arguments for "${tc.name}": ${tc.arguments}`, + ); + argsJson = "{}"; + } + + // tool-call-start + events.push({ + type: "tool-call-start", + index: idx, + delta: { + message: { + tool_calls: { + id: callId, + type: "function", + function: { + name: tc.name, + arguments: "", + }, + }, + }, + }, + }); + + // tool-call-delta — chunked arguments + for (let i = 0; i < argsJson.length; i += chunkSize) { + const slice = argsJson.slice(i, i + chunkSize); + events.push({ + type: "tool-call-delta", + index: idx, + delta: { + message: { + tool_calls: { + function: { + arguments: slice, + }, + }, + }, + }, + }); + } + + // tool-call-end + events.push({ + type: "tool-call-end", + index: idx, + }); + } + + // message-end + events.push({ + type: "message-end", + delta: { + finish_reason: cohereFinishReason(overrides?.finishReason, "TOOL_CALL"), + usage: cohereUsage(overrides), }, }); @@ -492,7 +801,6 @@ export async function handleCohere( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -501,7 +809,7 @@ export async function handleCohere( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -511,13 +819,13 @@ export async function handleCohere( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: req.url ?? 
"/v2/chat", headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } @@ -569,8 +877,65 @@ export async function handleCohere( return; } + // Content + tool calls response (must be checked before text/tool-only branches) + if (isContentWithToolCallsResponse(response)) { + if (response.webSearches?.length) { + logger.warn( + "webSearches in fixture response are not supported for Cohere v2 Chat API — ignoring", + ); + } + const overrides = extractOverrides(response); + const journalEntry = journal.add({ + method: req.method ?? "POST", + path: req.url ?? "/v2/chat", + headers: flattenHeaders(req.headers), + body: completionReq, + response: { status: 200, fixture }, + }); + if (cohereReq.stream !== true) { + const body = buildCohereContentWithToolCallsResponse( + response.content, + response.toolCalls, + logger, + response.reasoning, + overrides, + ); + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify(body)); + } else { + const events = buildCohereContentWithToolCallsStreamEvents( + response.content, + response.toolCalls, + chunkSize, + logger, + response.reasoning, + overrides, + ); + const interruption = createInterruptionSignal(fixture); + const completed = await writeCohereSSEStream(res, events, { + latency, + streamingProfile: fixture.streamingProfile, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); + } + return; + } + // Text response if (isTextResponse(response)) { + if (response.webSearches?.length) { + logger.warn( + "webSearches in fixture response are not supported for Cohere v2 Chat API — ignoring", + ); + } + const overrides = 
extractOverrides(response); const journalEntry = journal.add({ method: req.method ?? "POST", path: req.url ?? "/v2/chat", @@ -579,11 +944,16 @@ export async function handleCohere( response: { status: 200, fixture }, }); if (cohereReq.stream !== true) { - const body = buildCohereTextResponse(response.content); + const body = buildCohereTextResponse(response.content, response.reasoning, overrides); res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify(body)); } else { - const events = buildCohereTextStreamEvents(response.content, chunkSize); + const events = buildCohereTextStreamEvents( + response.content, + chunkSize, + response.reasoning, + overrides, + ); const interruption = createInterruptionSignal(fixture); const completed = await writeCohereSSEStream(res, events, { latency, @@ -603,6 +973,7 @@ export async function handleCohere( // Tool call response if (isToolCallResponse(response)) { + const overrides = extractOverrides(response); const journalEntry = journal.add({ method: req.method ?? "POST", path: req.url ?? "/v2/chat", @@ -611,11 +982,16 @@ export async function handleCohere( response: { status: 200, fixture }, }); if (cohereReq.stream !== true) { - const body = buildCohereToolCallResponse(response.toolCalls, logger); + const body = buildCohereToolCallResponse(response.toolCalls, logger, overrides); res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify(body)); } else { - const events = buildCohereToolCallStreamEvents(response.toolCalls, chunkSize, logger); + const events = buildCohereToolCallStreamEvents( + response.toolCalls, + chunkSize, + logger, + overrides, + ); const interruption = createInterruptionSignal(fixture); const completed = await writeCohereSSEStream(res, events, { latency, diff --git a/src/embeddings.ts b/src/embeddings.ts index 8da7f34..1291322 100644 --- a/src/embeddings.ts +++ b/src/embeddings.ts @@ -1,5 +1,5 @@ /** - * OpenAI Embeddings API support for LLMock. 
+ * OpenAI Embeddings API support for aimock. * * Handles POST /v1/embeddings requests. Matches fixtures using the `inputText` * field, and falls back to generating a deterministic embedding from the input @@ -7,7 +7,12 @@ */ import type * as http from "node:http"; -import type { ChatCompletionRequest, Fixture, HandlerDefaults } from "./types.js"; +import type { + ChatCompletionRequest, + Fixture, + HandlerDefaults, + RecordProviderKey, +} from "./types.js"; import { isEmbeddingResponse, isErrorResponse, @@ -42,6 +47,7 @@ export async function handleEmbeddings( journal: Journal, defaults: HandlerDefaults, setCorsHeaders: (res: http.ServerResponse) => void, + providerKey: RecordProviderKey = "openai", ): Promise { const { logger } = defaults; setCorsHeaders(res); @@ -71,6 +77,28 @@ export async function handleEmbeddings( return; } + // Validate required input parameter + if (embeddingReq.input === undefined || embeddingReq.input === null) { + journal.add({ + method: req.method ?? "POST", + path: req.url ?? "/v1/embeddings", + headers: flattenHeaders(req.headers), + body: null, + response: { status: 400, fixture: null }, + }); + writeErrorResponse( + res, + 400, + JSON.stringify({ + error: { + message: "Missing required parameter: 'input'", + type: "invalid_request_error", + }, + }), + ); + return; + } + // Normalize input to array of strings const inputs: string[] = Array.isArray(embeddingReq.input) ? embeddingReq.input @@ -113,7 +141,6 @@ export async function handleEmbeddings( headers: flattenHeaders(req.headers), body: syntheticReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -177,23 +204,23 @@ export async function handleEmbeddings( // No fixture match — try record-and-replay proxy if configured if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, syntheticReq, - "openai", + providerKey, req.url ?? 
"/v1/embeddings", fixtures, defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: req.url ?? "/v1/embeddings", headers: flattenHeaders(req.headers), body: syntheticReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } diff --git a/src/gemini.ts b/src/gemini.ts index 08cd4b9..232ab49 100644 --- a/src/gemini.ts +++ b/src/gemini.ts @@ -93,25 +93,25 @@ export function geminiToCompletionRequest( } if (req.contents) { + let callCounter = 0; for (const content of req.contents) { const role = content.role ?? "user"; if (role === "user") { // Check for functionResponse parts const funcResponses = content.parts.filter((p) => p.functionResponse); - const textParts = content.parts.filter((p) => p.text !== undefined); + const textParts = content.parts.filter((p) => p.text !== undefined && !p.thought); if (funcResponses.length > 0) { // functionResponse → tool message - for (let i = 0; i < funcResponses.length; i++) { - const part = funcResponses[i]; + for (const part of funcResponses) { messages.push({ role: "tool", content: typeof part.functionResponse!.response === "string" ? 
part.functionResponse!.response : JSON.stringify(part.functionResponse!.response), - tool_call_id: `call_gemini_${part.functionResponse!.name}_${i}`, + tool_call_id: `call_gemini_${part.functionResponse!.name}_${callCounter++}`, }); } // Any text parts alongside → user message @@ -129,18 +129,19 @@ export function geminiToCompletionRequest( } else if (role === "model") { // Check for functionCall parts const funcCalls = content.parts.filter((p) => p.functionCall); - const textParts = content.parts.filter((p) => p.text !== undefined); + const textParts = content.parts.filter((p) => p.text !== undefined && !p.thought); if (funcCalls.length > 0) { + const text = textParts.map((p) => p.text!).join(""); messages.push({ role: "assistant", - content: null, - tool_calls: funcCalls.map((p, i) => ({ - id: `call_gemini_${p.functionCall!.name}_${i}`, + content: text || null, + tool_calls: funcCalls.map((fc) => ({ + id: `call_gemini_${fc.functionCall!.name}_${callCounter++}`, type: "function" as const, function: { - name: p.functionCall!.name, - arguments: JSON.stringify(p.functionCall!.args), + name: fc.functionCall!.name, + arguments: JSON.stringify(fc.functionCall!.args ?? {}), }, })), }); @@ -149,6 +150,9 @@ export function geminiToCompletionRequest( messages.push({ role: "assistant", content: text }); } } + // Unrecognized roles (not "user" or "model") are silently dropped. + // Gemini only defines "user" and "model"; any other value indicates + // a malformed request or an unsupported future role. } } @@ -562,7 +566,6 @@ export async function handleGemini( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? 
"fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -571,7 +574,7 @@ export async function handleGemini( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -581,13 +584,13 @@ export async function handleGemini( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path, headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } diff --git a/src/images.ts b/src/images.ts index e1f8ef9..fe20f8b 100644 --- a/src/images.ts +++ b/src/images.ts @@ -77,6 +77,24 @@ export async function handleImages( return; } + if (!prompt) { + journal.add({ + method, + path, + headers: flattenHeaders(req.headers), + body: null, + response: { status: 400, fixture: null }, + }); + writeErrorResponse( + res, + 400, + JSON.stringify({ + error: { message: "Missing required parameter: 'prompt'", type: "invalid_request_error" }, + }), + ); + return; + } + const syntheticReq = buildSyntheticRequest(model, prompt); const testId = getTestId(req); const fixture = matchFixture( @@ -98,7 +116,6 @@ export async function handleImages( req.headers, journal, { method, path, headers: flattenHeaders(req.headers), body: syntheticReq }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -107,7 +124,7 @@ export async function handleImages( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, syntheticReq, @@ -117,13 +134,13 @@ export async function handleImages( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method, path, headers: flattenHeaders(req.headers), body: syntheticReq, - response: { status: res.statusCode ?? 
200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } diff --git a/src/messages.ts b/src/messages.ts index 8321957..9b61264 100644 --- a/src/messages.ts +++ b/src/messages.ts @@ -752,7 +752,6 @@ export async function handleMessages( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -761,7 +760,7 @@ export async function handleMessages( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -771,13 +770,13 @@ export async function handleMessages( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: req.url ?? "/v1/messages", headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } @@ -893,7 +892,7 @@ export async function handleMessages( // Text response if (isTextResponse(response)) { if (response.webSearches?.length) { - defaults.logger.warn( + logger.warn( "webSearches in fixture response are not supported for Claude Messages API — ignoring", ); } diff --git a/src/ollama.ts b/src/ollama.ts index 9d169b4..698567f 100644 --- a/src/ollama.ts +++ b/src/ollama.ts @@ -24,6 +24,7 @@ import type { import { isTextResponse, isToolCallResponse, + isContentWithToolCallsResponse, isErrorResponse, flattenHeaders, getTestId, @@ -108,7 +109,7 @@ export function ollamaToCompletionRequest(req: OllamaRequest): ChatCompletionReq return { model: req.model, messages, - stream: req.stream, + stream: req.stream ?? 
true, temperature: req.options?.temperature, max_tokens: req.options?.num_predict, tools, @@ -119,7 +120,7 @@ function ollamaGenerateToCompletionRequest(req: OllamaGenerateRequest): ChatComp return { model: req.model, messages: [{ role: "user", content: req.prompt }], - stream: req.stream, + stream: req.stream ?? true, temperature: req.options?.temperature, max_tokens: req.options?.num_predict, }; @@ -261,6 +262,103 @@ function buildOllamaChatToolCallResponse( }; } +// ─── Response builders: /api/chat — content + tool calls ──────────────────── + +function buildOllamaChatContentWithToolCallsChunks( + content: string, + toolCalls: ToolCall[], + model: string, + chunkSize: number, + logger: Logger, +): object[] { + const chunks: object[] = []; + + // Content chunks first + for (let i = 0; i < content.length; i += chunkSize) { + const slice = content.slice(i, i + chunkSize); + chunks.push({ + model, + message: { role: "assistant", content: slice }, + done: false, + }); + } + + // Tool calls in a single chunk (same as tool-call-only path) + const ollamaToolCalls = toolCalls.map((tc) => { + let argsObj: unknown; + try { + argsObj = JSON.parse(tc.arguments || "{}"); + } catch { + logger.warn( + `Malformed JSON in fixture tool call arguments for "${tc.name}": ${tc.arguments}`, + ); + argsObj = {}; + } + return { + function: { + name: tc.name, + arguments: argsObj, + }, + }; + }); + + chunks.push({ + model, + message: { + role: "assistant", + content: "", + tool_calls: ollamaToolCalls, + }, + done: false, + }); + + // Final chunk + chunks.push({ + model, + message: { role: "assistant", content: "" }, + done: true, + ...DURATION_FIELDS, + }); + + return chunks; +} + +function buildOllamaChatContentWithToolCallsResponse( + content: string, + toolCalls: ToolCall[], + model: string, + logger: Logger, +): object { + const ollamaToolCalls = toolCalls.map((tc) => { + let argsObj: unknown; + try { + argsObj = JSON.parse(tc.arguments || "{}"); + } catch { + logger.warn( + 
`Malformed JSON in fixture tool call arguments for "${tc.name}": ${tc.arguments}`, + ); + argsObj = {}; + } + return { + function: { + name: tc.name, + arguments: argsObj, + }, + }; + }); + + return { + model, + message: { + role: "assistant", + content, + tool_calls: ollamaToolCalls, + }, + done: true, + ...DURATION_FIELDS, + }; +} + // ─── Response builders: /api/generate ──────────────────────────────────────── function buildOllamaGenerateTextChunks( @@ -415,7 +513,6 @@ export async function handleOllama( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -424,7 +521,7 @@ export async function handleOllama( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -434,13 +531,13 @@ export async function handleOllama( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: urlPath, headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } @@ -493,6 +590,49 @@ export async function handleOllama( return; } + // Content + tool calls response (must be checked before text/tool-only branches) + if (isContentWithToolCallsResponse(response)) { + const journalEntry = journal.add({ + method: req.method ?? 
"POST", + path: urlPath, + headers: flattenHeaders(req.headers), + body: completionReq, + response: { status: 200, fixture }, + }); + if (!streaming) { + const body = buildOllamaChatContentWithToolCallsResponse( + response.content, + response.toolCalls, + completionReq.model, + logger, + ); + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify(body)); + } else { + const chunks = buildOllamaChatContentWithToolCallsChunks( + response.content, + response.toolCalls, + completionReq.model, + chunkSize, + logger, + ); + const interruption = createInterruptionSignal(fixture); + const completed = await writeNDJSONStream(res, chunks, { + latency, + streamingProfile: fixture.streamingProfile, + signal: interruption?.signal, + onChunkSent: interruption?.tick, + }); + if (!completed) { + if (!res.writableEnded) res.destroy(); + journalEntry.response.interrupted = true; + journalEntry.response.interruptReason = interruption?.reason(); + } + interruption?.cleanup(); + } + return; + } + // Text response if (isTextResponse(response)) { const journalEntry = journal.add({ @@ -675,7 +815,6 @@ export async function handleOllamaGenerate( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -684,7 +823,7 @@ export async function handleOllamaGenerate( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -694,13 +833,13 @@ export async function handleOllamaGenerate( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: urlPath, headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 
200, fixture: null, source: "proxy" }, }); return; } @@ -794,7 +933,29 @@ export async function handleOllamaGenerate( return; } - // Tool call responses not supported for /api/generate — fall through to error + // Tool call fixtures matched but not supported on /api/generate + if (isToolCallResponse(response) || isContentWithToolCallsResponse(response)) { + journal.add({ + method: req.method ?? "POST", + path: urlPath, + headers: flattenHeaders(req.headers), + body: completionReq, + response: { status: 400, fixture }, + }); + writeErrorResponse( + res, + 400, + JSON.stringify({ + error: { + message: "Tool call fixtures are not supported on /api/generate — use /api/chat instead", + type: "invalid_request_error", + }, + }), + ); + return; + } + + // Unknown response type journal.add({ method: req.method ?? "POST", path: urlPath, diff --git a/src/responses.ts b/src/responses.ts index 9e748a1..f7c91a6 100644 --- a/src/responses.ts +++ b/src/responses.ts @@ -1,5 +1,5 @@ /** - * OpenAI Responses API support for LLMock. + * OpenAI Responses API support for aimock. * * Translates incoming /v1/responses requests into the ChatCompletionRequest * format used by the fixture router, and converts fixture responses back into @@ -55,7 +55,7 @@ interface ResponsesContentPart { interface ResponsesRequest { model: string; - input: ResponsesInputItem[]; + input: string | ResponsesInputItem[]; instructions?: string; tools?: ResponsesToolDef[]; tool_choice?: string | object; @@ -92,6 +92,13 @@ export function responsesInputToMessages(req: ResponsesRequest): ChatMessage[] { messages.push({ role: "system", content: req.instructions }); } + // The OpenAI Responses API accepts either a plain string or an array of input items. + // When a string is passed, treat it as a single user message. 
+ if (typeof req.input === "string") { + messages.push({ role: "user", content: req.input }); + return messages; + } + for (const item of req.input) { if (item.role === "system" || item.role === "developer") { messages.push({ role: "system", content: extractTextContent(item.content) }); @@ -118,8 +125,11 @@ export function responsesInputToMessages(req: ResponsesRequest): ChatMessage[] { content: item.output ?? "", tool_call_id: item.call_id, }); + } else { + // Skip item_reference, local_shell_call, mcp_list_tools, etc. — not needed + // for fixture matching. Logging is not threaded into this pure conversion + // function; callers can inspect the returned messages if needed. } - // Skip item_reference, local_shell_call, etc. — not needed for fixture matching } return messages; @@ -875,7 +885,6 @@ export async function handleResponses( headers: flattenHeaders(req.headers), body: completionReq, }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -884,7 +893,7 @@ export async function handleResponses( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, completionReq, @@ -894,13 +903,13 @@ export async function handleResponses( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method: req.method ?? "POST", path: req.url ?? "/v1/responses", headers: flattenHeaders(req.headers), body: completionReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 
200, fixture: null, source: "proxy" }, }); return; } diff --git a/src/speech.ts b/src/speech.ts index eba2f3c..dddc902 100644 --- a/src/speech.ts +++ b/src/speech.ts @@ -59,6 +59,24 @@ export async function handleSpeech( return; } + if (!speechReq.input) { + journal.add({ + method, + path, + headers: flattenHeaders(req.headers), + body: null, + response: { status: 400, fixture: null }, + }); + writeErrorResponse( + res, + 400, + JSON.stringify({ + error: { message: "Missing required parameter: 'input'", type: "invalid_request_error" }, + }), + ); + return; + } + const syntheticReq: ChatCompletionRequest = { model: speechReq.model ?? "tts-1", messages: [{ role: "user", content: speechReq.input }], @@ -85,7 +103,6 @@ export async function handleSpeech( req.headers, journal, { method, path, headers: flattenHeaders(req.headers), body: syntheticReq }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -94,7 +111,7 @@ export async function handleSpeech( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, syntheticReq, @@ -104,13 +121,13 @@ export async function handleSpeech( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method, path, headers: flattenHeaders(req.headers), body: syntheticReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } diff --git a/src/transcription.ts b/src/transcription.ts index bcd83db..fa67d06 100644 --- a/src/transcription.ts +++ b/src/transcription.ts @@ -8,9 +8,12 @@ import { applyChaos } from "./chaos.js"; import { proxyAndRecord } from "./recorder.js"; /** - * Extract a named field value from a multipart/form-data body. - * Lightweight parser — scans for Content-Disposition headers - * to find simple string field values. + * Extract a text field from multipart form data using regex. 
+ * NOTE: This runs against the full body including binary audio data. + * It works because text metadata fields (model, response_format, etc.) + * appear before the binary audio part in standard multipart encoding. + * A proper multipart parser would be more robust but is overkill for + * the small set of fields we extract. */ function extractFormField(raw: string, fieldName: string): string | undefined { const pattern = new RegExp( @@ -63,7 +66,6 @@ export async function handleTranscription( req.headers, journal, { method, path, headers: flattenHeaders(req.headers), body: syntheticReq }, - fixture ? "fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -72,7 +74,7 @@ export async function handleTranscription( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, syntheticReq, @@ -82,13 +84,13 @@ export async function handleTranscription( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method, path, headers: flattenHeaders(req.headers), body: syntheticReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } diff --git a/src/video.ts b/src/video.ts index dd67cf7..6ee9c1b 100644 --- a/src/video.ts +++ b/src/video.ts @@ -13,8 +13,60 @@ interface VideoRequest { [key: string]: unknown; } -/** Stored video state for GET status checks. Key: `${testId}:${videoId}` */ -export type VideoStateMap = Map; +// ─── VideoStateMap with TTL and size bound ──────────────────────────────── + +const VIDEO_STATE_MAX_ENTRIES = 10_000; +const VIDEO_STATE_TTL_MS = 3_600_000; // 1 hour + +interface VideoStateEntry { + video: VideoResponse["video"]; + createdAt: number; +} + +/** + * A Map wrapper for video state that enforces a maximum size and per-entry TTL. + * Entries older than VIDEO_STATE_TTL_MS are lazily evicted on `get`. 
+ * When the map exceeds VIDEO_STATE_MAX_ENTRIES on `set`, the oldest entries + * are removed to stay within bounds. + */ +export class VideoStateMap { + private readonly entries = new Map(); + + get(key: string): VideoResponse["video"] | undefined { + const entry = this.entries.get(key); + if (!entry) return undefined; + if (Date.now() - entry.createdAt > VIDEO_STATE_TTL_MS) { + this.entries.delete(key); + return undefined; + } + return entry.video; + } + + set(key: string, video: VideoResponse["video"]): void { + this.entries.set(key, { video, createdAt: Date.now() }); + // Evict oldest entries if over capacity + if (this.entries.size > VIDEO_STATE_MAX_ENTRIES) { + const excess = this.entries.size - VIDEO_STATE_MAX_ENTRIES; + const iter = this.entries.keys(); + for (let i = 0; i < excess; i++) { + const next = iter.next(); + if (!next.done) this.entries.delete(next.value); + } + } + } + + delete(key: string): boolean { + return this.entries.delete(key); + } + + clear(): void { + this.entries.clear(); + } + + get size(): number { + return this.entries.size; + } +} export async function handleVideoCreate( req: http.IncomingMessage, @@ -51,6 +103,24 @@ export async function handleVideoCreate( return; } + if (!videoReq.prompt) { + journal.add({ + method, + path, + headers: flattenHeaders(req.headers), + body: null, + response: { status: 400, fixture: null }, + }); + writeErrorResponse( + res, + 400, + JSON.stringify({ + error: { message: "Missing required parameter: 'prompt'", type: "invalid_request_error" }, + }), + ); + return; + } + const syntheticReq: ChatCompletionRequest = { model: videoReq.model ?? "sora-2", messages: [{ role: "user", content: videoReq.prompt }], @@ -77,7 +147,6 @@ export async function handleVideoCreate( req.headers, journal, { method, path, headers: flattenHeaders(req.headers), body: syntheticReq }, - fixture ? 
"fixture" : "proxy", defaults.registry, defaults.logger, ) @@ -86,7 +155,7 @@ export async function handleVideoCreate( if (!fixture) { if (defaults.record) { - const outcome = await proxyAndRecord( + const proxied = await proxyAndRecord( req, res, syntheticReq, @@ -96,13 +165,13 @@ export async function handleVideoCreate( defaults, raw, ); - if (outcome !== "not_configured") { + if (proxied) { journal.add({ method, path, headers: flattenHeaders(req.headers), body: syntheticReq, - response: { status: res.statusCode ?? 200, fixture: null }, + response: { status: res.statusCode ?? 200, fixture: null, source: "proxy" }, }); return; } @@ -191,7 +260,6 @@ export function handleVideoStatus( res: http.ServerResponse, videoId: string, journal: Journal, - defaults: HandlerDefaults, setCorsHeaders: (res: http.ServerResponse) => void, videoStates: VideoStateMap, ): void { @@ -199,21 +267,6 @@ export function handleVideoStatus( const path = req.url ?? `/v1/videos/${videoId}`; const method = req.method ?? "GET"; - if ( - applyChaos( - res, - null, - defaults.chaos, - req.headers, - journal, - { method, path, headers: flattenHeaders(req.headers), body: null }, - "fixture", - defaults.registry, - defaults.logger, - ) - ) - return; - const testId = getTestId(req); const stateKey = `${testId}:${videoId}`; const video = videoStates.get(stateKey); From d0362355ddd241a6e5a15a44a3b12c4449a9efe3 Mon Sep 17 00:00:00 2001 From: Jordan Ritter Date: Thu, 23 Apr 2026 16:23:24 -0700 Subject: [PATCH 4/6] fix: competitive matrix HTML pipeline computeChanges matches actual span class="no" structure, applyChanges regex matches td>span, updateProviderCounts scoped to competitor column, extractFeatures tightened keywords, added mokksy/ai-mocks competitor. 
--- docs/chaos-testing/index.html | 83 +-- scripts/update-competitive-matrix.ts | 145 +++-- src/__tests__/competitive-matrix.test.ts | 688 ++++++++++++++++++++--- 3 files changed, 711 insertions(+), 205 deletions(-) diff --git a/docs/chaos-testing/index.html b/docs/chaos-testing/index.html index 22de36c..3a51a32 100644 --- a/docs/chaos-testing/index.html +++ b/docs/chaos-testing/index.html @@ -69,10 +69,7 @@

Failure Modes

HTTP 500 Returns a 500 error with - {"error":{"message":"Chaos: request - dropped","type":"server_error","code":"chaos_drop"}} + {"error":{"message":"Chaos: request dropped","code":"chaos_drop"}} @@ -207,7 +204,7 @@

Per-Request Headers

"Content-Type": "application/json", "x-aimock-chaos-disconnect": "1.0", }, - body: JSON.stringify({ model: "gpt-4", messages: [{ role: "user", content: "hello" }] }), + body: JSON.stringify({ model: "gpt-4", messages: [...] }), }); @@ -220,7 +217,7 @@

CLI Flags

CLI chaos flags shell
-
$ npx -p @copilotkit/aimock aimock --fixtures ./fixtures \
+              
$ npx -p @copilotkit/aimock aimock --fixtures ./fixtures \
   --chaos-drop 0.1 \
   --chaos-malformed 0.05 \
   --chaos-disconnect 0.02
@@ -242,55 +239,6 @@

CLI Flags

-

Proxy Mode

-

- When aimock is configured as a record/replay proxy (--record), chaos applies - to proxied requests too — so a staging environment pointed at real upstream APIs - still sees the failure modes your tests expect. Chaos is rolled once per request, - after fixture matching, with the same headers > fixture > server - precedence. -

- - - - - - - - - - - - - - - - - - - - - - - - - -
ModeWhen upstream is contactedWhat the client sees
dropNever — upstream not contactedHTTP 500 chaos body; upstream is not called
disconnectNever — upstream not contactedConnection destroyed; upstream is not called
malformedCalled — post-response - Request proxies normally; the upstream response is captured, then the body is - replaced with invalid JSON before relay. The recorded fixture (if recording) keeps - the real upstream response — chaos is a live-traffic decoration, not a fixture - mutation. -
-

- SSE bypass. If upstream returns - Content-Type: text/event-stream, aimock streams chunks to the client - progressively. By the time malformed would fire, the bytes are already on the - wire — the chaos action cannot be applied. This bypass is observable via the - aimock_chaos_bypassed_total counter (see Prometheus Metrics below) and a - warning in the server log, so a configured chaos rate doesn't silently drop to 0% on SSE - traffic. Streaming mutation is planned for a future phase. -

-

Journal Tracking

When chaos triggers, the journal entry includes a chaosAction field recording @@ -306,7 +254,6 @@

Journal Tracking

"path": "/v1/chat/completions", "response": { "status": 500, - "source": "fixture", "fixture": { "...": "elided for brevity" }, "chaosAction": "drop" } @@ -321,31 +268,15 @@

Journal Tracking

Prometheus Metrics

When metrics are enabled (--metrics), each chaos trigger increments the - aimock_chaos_triggered_total counter, tagged with action and - source. source="fixture" means a fixture matched (or would have, - before chaos intervened); source="proxy" means the request was on the proxy - dispatch path. + aimock_chaos_triggered_total counter with an action label:

Metrics output text
# TYPE aimock_chaos_triggered_total counter
-aimock_chaos_triggered_total{action="drop",source="fixture"} 3
-aimock_chaos_triggered_total{action="malformed",source="fixture"} 1
-aimock_chaos_triggered_total{action="disconnect",source="proxy"} 2
-
- -

- When a chaos action is rolled but can't be applied — today, only - malformed on an SSE proxy response — the bypass is recorded in a - separate counter so operators can distinguish "chaos didn't roll" from "chaos rolled but - was bypassed": -

- -
-
Bypass counter text
-
# TYPE aimock_chaos_bypassed_total counter
-aimock_chaos_bypassed_total{action="malformed",source="proxy",reason="sse_streamed"} 4
+aimock_chaos_triggered_total{action="drop"} 3 +aimock_chaos_triggered_total{action="malformed"} 1 +aimock_chaos_triggered_total{action="disconnect"} 2
diff --git a/scripts/update-competitive-matrix.ts b/scripts/update-competitive-matrix.ts index 1454de8..0f80e6d 100644 --- a/scripts/update-competitive-matrix.ts +++ b/scripts/update-competitive-matrix.ts @@ -71,23 +71,29 @@ const FEATURE_RULES: FeatureRule[] = [ }, { rowLabel: "Embeddings API", - keywords: ["embedding", "/v1/embeddings", "embed"], + keywords: ["/v1/embeddings", "embeddings api", "embedding endpoint", "embedding model"], }, { rowLabel: "Image generation", - keywords: ["image", "dall-e", "dalle", "/v1/images", "image generation", "imagen"], + keywords: ["dall-e", "dalle", "/v1/images", "image generation", "imagen", "generate.*image"], }, { rowLabel: "Text-to-Speech", - keywords: ["tts", "text-to-speech", "speech", "/v1/audio/speech", "audio generation"], + keywords: ["text-to-speech", "/v1/audio/speech", "audio generation", "tts endpoint", "tts api"], }, { rowLabel: "Audio transcription", - keywords: ["transcription", "whisper", "/v1/audio/transcriptions", "speech-to-text", "stt"], + keywords: [ + "/v1/audio/transcriptions", + "whisper", + "speech-to-text", + "audio transcription", + "transcription api", + ], }, { rowLabel: "Video generation", - keywords: ["video", "sora", "/v1/videos", "video generation"], + keywords: ["sora", "/v1/videos", "video generation", "generate.*video"], }, { rowLabel: "Structured output / JSON mode", @@ -107,11 +113,11 @@ const FEATURE_RULES: FeatureRule[] = [ }, { rowLabel: "Docker image", - keywords: ["docker", "dockerfile", "container", "docker-compose"], + keywords: ["dockerfile", "docker image", "docker-compose", "docker compose", "docker run"], }, { rowLabel: "Helm chart", - keywords: ["helm", "chart", "kubernetes", "k8s"], + keywords: ["helm chart", "helm install", "kubernetes.*deploy", "k8s.*deploy"], }, { rowLabel: "Fixture files (JSON)", @@ -342,10 +348,12 @@ function buildMigrationRowPatterns(rowLabel: string): string[] { /** * Scans the HTML for numeric provider claims and updates them if the detected - * 
count is higher. Handles patterns like: - * - "N providers" / "N+ providers" (in prose and table cells) - * - "supports N LLM" / "N LLM providers" - * - "N more providers" + * count is higher. Only replaces within content scoped to the specific competitor + * to avoid corrupting aimock's own claims or other competitors' counts. + * + * Scoping strategy: only replace inside elements/paragraphs that mention the + * competitor by name, or within the competitor's column in a table row whose + * label matches "provider" (case-insensitive). */ function updateProviderCounts( html: string, @@ -354,30 +362,85 @@ function updateProviderCounts( changes: string[], ): string { let result = html; + const escapedName = escapeRegex(competitorName); + + // Strategy 1: Replace provider counts in table rows about providers, + // scoped to the competitor's column. Find rows with "provider" in the label, + // then find the competitor's column cell by index. + const tableMatch = result.match( + /([\s\S]*?)<\/table>/, + ); + if (tableMatch) { + const fullTable = tableMatch[0]; + + // Find the competitor's column index from headers + const thRegex = /]*>([\s\S]*?)<\/th>/g; + const thTexts: string[] = []; + let thM: RegExpExecArray | null; + while ((thM = thRegex.exec(fullTable)) !== null) { + thTexts.push(thM[1].trim()); + } + const compColIdx = thTexts.findIndex((t) => t.includes(competitorName) || t === competitorName); + + if (compColIdx >= 0) { + // Find provider-related rows and update only the competitor's cell + const updatedTable = fullTable.replace( + /([\s\S]*?)<\/tr>/g, + (trMatch, trContent: string) => { + // Check if this row is about providers + const firstTd = trContent.match(/]*>([\s\S]*?)<\/td>/); + if (!firstTd || !/provider/i.test(firstTd[1])) return trMatch; + + // Replace provider count only in the competitor's column cell + let cellIdx = 0; + return trMatch.replace(/]*>([\s\S]*?)<\/td>/g, (tdMatch, tdContent: string) => { + const currentIdx = cellIdx++; + if 
(currentIdx !== compColIdx) return tdMatch; + + const updated = replaceProviderCount(tdContent, detectedCount); + if (updated !== tdContent) { + const oldCount = tdContent.match(/(\d+)/)?.[1] ?? "?"; + changes.push( + `${competitorName}: provider count ${oldCount} -> ${detectedCount} (table)`, + ); + return tdMatch.replace(tdContent, updated); + } + return tdMatch; + }); + }, + ); + + result = result.replace(fullTable, updatedTable); + } + } - // Pattern: N+ providers or N providers (in table cells and prose) - const providerCountRegex = /(\d+)\+?\s*providers/g; - result = result.replace(providerCountRegex, (match, numStr) => { + // Strategy 2: Replace provider counts in prose paragraphs/sentences that + // explicitly mention the competitor by name. + const prosePattern = new RegExp( + `(<[^>]*>[^<]*${escapedName}[^<]*)(\\d+)\\+?\\s*(?:LLM\\s*)?providers?`, + "gi", + ); + result = result.replace(prosePattern, (match, prefix, numStr) => { const currentCount = parseInt(numStr, 10); if (detectedCount > currentCount) { - changes.push(`${competitorName}: provider count ${currentCount} -> ${detectedCount}`); - return `${detectedCount} providers`; + changes.push(`${competitorName}: provider count ${currentCount} -> ${detectedCount} (prose)`); + return match.replace(/(\d+)\+?\s*(?:LLM\s*)?providers?/, `${detectedCount} providers`); } return match; }); - // Pattern: "supports N LLM" or "N LLM providers" - const llmProviderRegex = /(\d+)\+?\s*LLM\s*providers?/g; - result = result.replace(llmProviderRegex, (match, numStr) => { + return result; +} + +/** Replaces "N providers" or "N+ providers" in a string if detected > current */ +function replaceProviderCount(text: string, detectedCount: number): string { + return text.replace(/(\d+)\+?\s*(?:LLM\s*)?providers?/gi, (match, numStr) => { const currentCount = parseInt(numStr, 10); if (detectedCount > currentCount) { - changes.push(`${competitorName}: LLM provider count ${currentCount} -> ${detectedCount}`); - return 
`${detectedCount} LLM providers`; + return `${detectedCount} providers`; } return match; }); - - return result; } // ── HTML Matrix Parsing & Updating ─────────────────────────────────────────── @@ -457,8 +520,14 @@ function computeChanges( const currentCell = row.get(compName); if (!currentCell) continue; - // Only upgrade "No" cells — leave "Yes", "Partial", "Manual", etc. alone - if (currentCell === "No") { + // Only upgrade "No" cells — leave "Yes", "Partial", "Manual", etc. alone. + // Cells contain inner HTML like '', + // not bare "No" text, so check for the no-class span or cross-mark entity. + if ( + currentCell.includes('class="no"') || + currentCell.includes("\u2717") || + currentCell.includes("✗") + ) { changes.push({ competitor: compName, capability: rowLabel, @@ -522,19 +591,23 @@ function applyChanges(html: string, changes: DetectedChange[]): string { const cellsHtml = rowMatch[2]; const suffix = rowMatch[3]; - // Find the Nth (class is on span, not td), + // so we match all and check inner content for no-class spans. 
const targetTdIdx = colIdx - 1; // 0-based within the remaining cells let tdCount = 0; - const tdReplace = cellsHtml.replace( - /`; - } - return fullMatch; - }, - ); + const tdReplace = cellsHtml.replace(/]*>([\s\S]*?)<\/td>/g, (fullMatch, content) => { + const currentIdx = tdCount++; + if ( + currentIdx === targetTdIdx && + (content.includes('class="no"') || + content.includes("\u2717") || + content.includes("✗")) + ) { + return ``; + } + return fullMatch; + }); result = result.replace(rowPattern, prefix + tdReplace + suffix); } diff --git a/src/__tests__/competitive-matrix.test.ts b/src/__tests__/competitive-matrix.test.ts index 5495b49..a63b431 100644 --- a/src/__tests__/competitive-matrix.test.ts +++ b/src/__tests__/competitive-matrix.test.ts @@ -49,7 +49,19 @@ const FEATURE_RULES: FeatureRule[] = [ }, { rowLabel: "Embeddings API", - keywords: ["embedding", "/v1/embeddings", "embed"], + keywords: ["/v1/embeddings", "embeddings api", "embedding endpoint", "embedding model"], + }, + { + rowLabel: "Image generation", + keywords: ["dall-e", "dalle", "/v1/images", "image generation", "imagen", "generate.*image"], + }, + { + rowLabel: "Video generation", + keywords: ["sora", "/v1/videos", "video generation", "generate.*video"], + }, + { + rowLabel: "Docker image", + keywords: ["dockerfile", "docker image", "docker-compose", "docker compose", "docker run"], }, { rowLabel: "Structured output / JSON mode", @@ -57,6 +69,19 @@ const FEATURE_RULES: FeatureRule[] = [ }, ]; +function extractFeatures(text: string): Record { + const lower = text.toLowerCase(); + const result: Record = {}; + for (const rule of FEATURE_RULES) { + const found = rule.keywords.some((kw) => { + const pattern = new RegExp(kw.toLowerCase(), "i"); + return pattern.test(lower); + }); + result[rule.rowLabel] = found; + } + return result; +} + function escapeRegex(str: string): string { return str.replace(/[.*+?^${}()|[\]\\/]/g, "\\$&"); } @@ -86,7 +111,18 @@ function buildMigrationRowPatterns(rowLabel: 
string): string[] { return patterns; } -// ── Provider count update logic ───────────────────────────────────────────── +// ── Provider count update logic (scoped version) ─────────────────────────── + +/** Replaces "N providers" or "N+ providers" in a string if detected > current */ +function replaceProviderCount(text: string, detectedCount: number): string { + return text.replace(/(\d+)\+?\s*(?:LLM\s*)?providers?/gi, (match, numStr) => { + const currentCount = parseInt(numStr, 10); + if (detectedCount > currentCount) { + return `${detectedCount} providers`; + } + return match; + }); +} function updateProviderCounts( html: string, @@ -95,23 +131,65 @@ function updateProviderCounts( changes: string[], ): string { let result = html; + const escapedName = escapeRegex(competitorName); - const providerCountRegex = /(\d+)\+?\s*providers/g; - result = result.replace(providerCountRegex, (match, numStr) => { - const currentCount = parseInt(numStr, 10); - if (detectedCount > currentCount) { - changes.push(`${competitorName}: provider count ${currentCount} -> ${detectedCount}`); - return `${detectedCount} providers`; + // Strategy 1: Replace provider counts in table rows about providers, + // scoped to the competitor's column. + const tableMatch = result.match( + /
in cellsHtml (colIdx - 1 because the first is already in prefix) + // Find the Nth in cellsHtml (colIdx - 1 because the first is already in prefix). + // Actual cells use ...([\s\S]*?)<\/td>/g, - (fullMatch, cls, content) => { - const currentIdx = tdCount++; - if (currentIdx === targetTdIdx && content.trim() === "No") { - return `Yes
([\s\S]*?)<\/table>/, + ); + if (tableMatch) { + const fullTable = tableMatch[0]; + + // Find the competitor's column index from headers + const thRegex = /]*>([\s\S]*?)<\/th>/g; + const thTexts: string[] = []; + let thM: RegExpExecArray | null; + while ((thM = thRegex.exec(fullTable)) !== null) { + thTexts.push(thM[1].trim()); } - return match; - }); + const compColIdx = thTexts.findIndex((t) => t.includes(competitorName) || t === competitorName); + + if (compColIdx >= 0) { + const updatedTable = fullTable.replace( + /([\s\S]*?)<\/tr>/g, + (trMatch, trContent: string) => { + const firstTd = trContent.match(/]*>([\s\S]*?)<\/td>/); + if (!firstTd || !/provider/i.test(firstTd[1])) return trMatch; + + let cellIdx = 0; + return trMatch.replace(/]*>([\s\S]*?)<\/td>/g, (tdMatch, tdContent: string) => { + const currentIdx = cellIdx++; + if (currentIdx !== compColIdx) return tdMatch; + + const updated = replaceProviderCount(tdContent, detectedCount); + if (updated !== tdContent) { + const oldCount = tdContent.match(/(\d+)/)?.[1] ?? "?"; + changes.push( + `${competitorName}: provider count ${oldCount} -> ${detectedCount} (table)`, + ); + return tdMatch.replace(tdContent, updated); + } + return tdMatch; + }); + }, + ); - const llmProviderRegex = /(\d+)\+?\s*LLM\s*providers?/g; - result = result.replace(llmProviderRegex, (match, numStr) => { + result = result.replace(fullTable, updatedTable); + } + } + + // Strategy 2: Replace provider counts in prose paragraphs/sentences that + // explicitly mention the competitor by name. 
+ const prosePattern = new RegExp( + `(<[^>]*>[^<]*${escapedName}[^<]*)(\\d+)\\+?\\s*(?:LLM\\s*)?providers?`, + "gi", + ); + result = result.replace(prosePattern, (match, prefix, numStr) => { const currentCount = parseInt(numStr, 10); if (detectedCount > currentCount) { - changes.push(`${competitorName}: LLM provider count ${currentCount} -> ${detectedCount}`); - return `${detectedCount} LLM providers`; + changes.push(`${competitorName}: provider count ${currentCount} -> ${detectedCount} (prose)`); + return match.replace(/(\d+)\+?\s*(?:LLM\s*)?providers?/, `${detectedCount} providers`); } return match; }); @@ -159,6 +237,156 @@ function updateMigrationPage( return { html: result, changes }; } +// ── parseCurrentMatrix reimplementation for testing ──────────────────────── + +function parseCurrentMatrix(html: string): { + headers: string[]; + rows: Map>; +} { + const tableMatch = html.match(/
([\s\S]*?)<\/table>/); + if (!tableMatch) { + throw new Error("Could not find comparison-table in HTML"); + } + const tableHtml = tableMatch[1]; + + const thRegex = /]*>[\s\S]*?]*>(.*?)<\/a[\s\S]*?<\/th>/g; + const headers: string[] = []; + let m: RegExpExecArray | null; + while ((m = thRegex.exec(tableHtml)) !== null) { + headers.push(m[1].trim()); + } + + const rows = new Map>(); + const tbody = tableHtml.match(/([\s\S]*?)<\/tbody>/)?.[1] ?? ""; + let tr: RegExpExecArray | null; + const trIter = new RegExp(/([\s\S]*?)<\/tr>/g); + + while ((tr = trIter.exec(tbody)) !== null) { + const tds: string[] = []; + const tdRegex = /]*>([\s\S]*?)<\/td>/g; + let td: RegExpExecArray | null; + while ((td = tdRegex.exec(tr[1])) !== null) { + tds.push(td[1].trim()); + } + if (tds.length < 2) continue; + + const rowLabel = tds[0]; + const rowMap = new Map(); + for (let i = 1; i < tds.length && i - 1 < headers.length; i++) { + rowMap.set(headers[i - 1], tds[i]); + } + rows.set(rowLabel, rowMap); + } + + return { headers, rows }; +} + +// ── computeChanges reimplementation (mirrors fixed version) ──────────────── + +interface DetectedChange { + competitor: string; + capability: string; + from: string; + to: string; +} + +function computeChanges( + _html: string, + matrix: { headers: string[]; rows: Map> }, + competitorFeatures: Map>, +): DetectedChange[] { + const changes: DetectedChange[] = []; + + for (const [compName, features] of competitorFeatures) { + for (const [rowLabel, detected] of Object.entries(features)) { + if (!detected) continue; + + const row = matrix.rows.get(rowLabel); + if (!row) continue; + + const currentCell = row.get(compName); + if (!currentCell) continue; + + // Only upgrade "No" cells — cells contain inner HTML like + // '', not bare "No" text. 
+ if ( + currentCell.includes('class="no"') || + currentCell.includes("\u2717") || + currentCell.includes("✗") + ) { + changes.push({ + competitor: compName, + capability: rowLabel, + from: "No", + to: "Yes", + }); + } + } + } + + return changes; +} + +// ── applyChanges reimplementation (mirrors fixed version) ────────────────── + +function applyChanges(html: string, changes: DetectedChange[]): string { + if (changes.length === 0) return html; + + const tableMatch = html.match(/
([\s\S]*?)<\/table>/); + if (!tableMatch) return html; + + const theadMatch = tableMatch[1].match(/([\s\S]*?)<\/thead>/); + if (!theadMatch) return html; + + const thRegex = /]*>[\s\S]*?]*>(.*?)<\/a[\s\S]*?<\/th>/g; + const headers: string[] = []; + let m: RegExpExecArray | null; + while ((m = thRegex.exec(theadMatch[1])) !== null) { + headers.push(m[1].trim()); + } + + const compColumnIndex = (name: string): number => { + const idx = headers.indexOf(name); + return idx === -1 ? -1 : idx + 1; + }; + + let result = html; + + for (const change of changes) { + const colIdx = compColumnIndex(change.competitor); + if (colIdx === -1) continue; + + const rowPattern = new RegExp( + `(\\s*)([\\s\\S]*?)()`, + ); + const rowMatch = result.match(rowPattern); + if (!rowMatch) continue; + + const prefix = rowMatch[1]; + const cellsHtml = rowMatch[2]; + const suffix = rowMatch[3]; + + const targetTdIdx = colIdx - 1; + let tdCount = 0; + const tdReplace = cellsHtml.replace(/]*>([\s\S]*?)<\/td>/g, (fullMatch, content) => { + const currentIdx = tdCount++; + if ( + currentIdx === targetTdIdx && + (content.includes('class="no"') || + content.includes("\u2717") || + content.includes("✗")) + ) { + return ``; + } + return fullMatch; + }); + + result = result.replace(rowPattern, prefix + tdReplace + suffix); + } + + return result; +} + // ── Tests ─────────────────────────────────────────────────────────────────── describe("provider count extraction from README text", () => { @@ -254,7 +482,7 @@ describe("migration page table update logic", () => { const { html } = updateMigrationPage(SAMPLE_TABLE, "TestComp", features, 0); - // Streaming SSE was already ✓, should remain unchanged + // Streaming SSE was already checkmark, should remain unchanged expect(html).toContain( '\n ', ); @@ -305,80 +533,140 @@ describe("migration page table update logic", () => { }); }); -describe("numeric provider claim updates", () => { - it('updates "5 providers" to "8 providers" when detected count is higher', 
() => { - const html = "

Supports 5 providers out of the box.

"; +describe("scoped provider count updates", () => { + it("updates competitor column in provider table row", () => { + const html = ` +
\\s*${escapeRegex(change.capability)}\\s*
Streaming SSE
+ + + + + + + + + + + + + + +
CapabilityTestCompaimock
LLM providers5 providers11 providers
`; const changes: string[] = []; const result = updateProviderCounts(html, "TestComp", 8, changes); + // TestComp's cell should be updated expect(result).toContain("8 providers"); - expect(result).not.toContain("5 providers"); + // aimock's 11 providers should be left alone + expect(result).toContain("11 providers"); expect(changes.length).toBe(1); }); - it('updates "5+ providers" to "8 providers" (strips the +)', () => { - const html = "5+ providers"; + it("does not corrupt aimock's own provider count", () => { + const html = ` + + + + + + + + + + + + + + + +
CapabilityaimockTestComp
Multi-provider support11 providers5 providers
`; const changes: string[] = []; const result = updateProviderCounts(html, "TestComp", 8, changes); + // aimock's count must remain 11 + expect(result).toContain("11 providers"); + // TestComp's count should be updated to 8 expect(result).toContain("8 providers"); - expect(result).not.toContain("5+"); }); - it("does not update when detected count is lower or equal", () => { - const html = "

Supports 10 providers.

"; + it("updates prose mentioning the competitor by name", () => { + const html = "

TestComp supports 5 providers today.

"; const changes: string[] = []; const result = updateProviderCounts(html, "TestComp", 8, changes); - expect(result).toContain("10 providers"); - expect(changes).toHaveLength(0); + expect(result).toContain("8 providers"); + expect(changes.length).toBe(1); }); - it("updates N LLM providers pattern", () => { - const html = "

supports 3 LLM providers

"; + it("does not update prose about aimock when updating competitor", () => { + const html = "

aimock supports 11 providers natively.

"; const changes: string[] = []; - const result = updateProviderCounts(html, "TestComp", 7, changes); + const result = updateProviderCounts(html, "TestComp", 15, changes); - expect(result).toContain("7 LLM providers"); - expect(changes.length).toBe(1); + // aimock's claim in prose should not be touched + expect(result).toContain("11 providers"); + expect(changes).toHaveLength(0); }); - it("handles no numeric claims gracefully", () => { - const html = "

A great testing tool.

"; + it("does not update when detected count is lower or equal", () => { + const html = ` + + + + + + + + + + + + + +
CapabilityTestComp
LLM providers10 providers
`; const changes: string[] = []; - const result = updateProviderCounts(html, "TestComp", 5, changes); + const result = updateProviderCounts(html, "TestComp", 8, changes); - expect(result).toBe(html); + expect(result).toContain("10 providers"); expect(changes).toHaveLength(0); }); - it("handles multiple provider count references in one document", () => { - const html = ` -

Supports 5 providers including OpenAI.

- 5+ providers - `; + it("handles no numeric claims gracefully", () => { + const html = "

A great testing tool.

"; const changes: string[] = []; - const result = updateProviderCounts(html, "TestComp", 9, changes); + const result = updateProviderCounts(html, "TestComp", 5, changes); - // Both occurrences should be updated - expect(result).not.toContain("5 providers"); - expect(result).not.toContain("5+"); - expect((result.match(/9 providers/g) || []).length).toBe(2); + expect(result).toBe(html); + expect(changes).toHaveLength(0); }); it("does not change provider count when equal", () => { - const html = "8 providers"; + const html = ` + + + + + + + + + + + + + +
CapabilityTestComp
LLM providers8 providers
`; const changes: string[] = []; const result = updateProviderCounts(html, "TestComp", 8, changes); - expect(result).toBe(html); + expect(result).toContain("8 providers"); expect(changes).toHaveLength(0); }); }); @@ -420,8 +708,7 @@ describe("migration page update with provider counts", () => { // Feature cell should be updated expect(html).not.toContain("✗"); - // Provider count in prose should be updated - expect(html).toContain("8 providers"); + // Provider count should be updated somewhere expect(changes.length).toBeGreaterThanOrEqual(2); }); @@ -460,50 +747,6 @@ describe("buildMigrationRowPatterns", () => { }); }); -// ── parseCurrentMatrix reimplementation for testing ──────────────────────── - -function parseCurrentMatrix(html: string): { - headers: string[]; - rows: Map>; -} { - const tableMatch = html.match(/([\s\S]*?)<\/table>/); - if (!tableMatch) { - throw new Error("Could not find comparison-table in HTML"); - } - const tableHtml = tableMatch[1]; - - const thRegex = /]*>[\s\S]*?]*>(.*?)<\/a[\s\S]*?<\/th>/g; - const headers: string[] = []; - let m: RegExpExecArray | null; - while ((m = thRegex.exec(tableHtml)) !== null) { - headers.push(m[1].trim()); - } - - const rows = new Map>(); - const tbody = tableHtml.match(/([\s\S]*?)<\/tbody>/)?.[1] ?? ""; - let tr: RegExpExecArray | null; - const trIter = new RegExp(/([\s\S]*?)<\/tr>/g); - - while ((tr = trIter.exec(tbody)) !== null) { - const tds: string[] = []; - const tdRegex = /]*>([\s\S]*?)<\/td>/g; - let td: RegExpExecArray | null; - while ((td = tdRegex.exec(tr[1])) !== null) { - tds.push(td[1].trim()); - } - if (tds.length < 2) continue; - - const rowLabel = tds[0]; - const rowMap = new Map(); - for (let i = 1; i < tds.length && i - 1 < headers.length; i++) { - rowMap.set(headers[i - 1], tds[i]); - } - rows.set(rowLabel, rowMap); - } - - return { headers, rows }; -} - describe("parseCurrentMatrix header extraction", () => { const MATRIX_WITH_LINKS = `
@@ -576,3 +819,262 @@ describe("parseCurrentMatrix header extraction", () => { expect(headers).toHaveLength(0); }); }); + +// ── computeChanges tests with actual HTML structure ──────────────────────── + +describe("computeChanges with actual HTML cell structure", () => { + // This matrix uses the actual HTML structure from docs/index.html: + // cells contain not bare "No" + const ACTUAL_HTML_MATRIX = ` +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
CapabilityaimockMSWVidaiMockmock-llm
WebSocket APIsBuilt-in ✓
Chat Completions SSEBuilt-in ✓manual
Embeddings APIBuilt-in ✓
`; + + it("detects changes when cells contain span.no markup", () => { + const matrix = parseCurrentMatrix(ACTUAL_HTML_MATRIX); + const features = new Map>(); + features.set("VidaiMock", { + "WebSocket APIs": true, + "Chat Completions SSE": true, + "Embeddings API": false, + }); + + const changes = computeChanges(ACTUAL_HTML_MATRIX, matrix, features); + + // VidaiMock WebSocket APIs cell has -> should be detected + expect(changes).toHaveLength(1); + expect(changes[0].competitor).toBe("VidaiMock"); + expect(changes[0].capability).toBe("WebSocket APIs"); + }); + + it("does not flag already-yes cells as changes", () => { + const matrix = parseCurrentMatrix(ACTUAL_HTML_MATRIX); + const features = new Map>(); + features.set("VidaiMock", { + "Chat Completions SSE": true, // already + "WebSocket APIs": false, + "Embeddings API": false, + }); + + const changes = computeChanges(ACTUAL_HTML_MATRIX, matrix, features); + + expect(changes).toHaveLength(0); + }); + + it("does not flag manual cells as changes", () => { + const matrix = parseCurrentMatrix(ACTUAL_HTML_MATRIX); + const features = new Map>(); + features.set("MSW", { + "Chat Completions SSE": true, // MSW has manual + "WebSocket APIs": false, + "Embeddings API": false, + }); + + const changes = computeChanges(ACTUAL_HTML_MATRIX, matrix, features); + + // MSW's manual cell should not trigger a change + expect(changes).toHaveLength(0); + }); + + it("detects changes for multiple competitors at once", () => { + const matrix = parseCurrentMatrix(ACTUAL_HTML_MATRIX); + const features = new Map>(); + features.set("VidaiMock", { + "WebSocket APIs": true, + "Chat Completions SSE": false, + "Embeddings API": false, + }); + features.set("mock-llm", { + "WebSocket APIs": true, + "Chat Completions SSE": false, + "Embeddings API": true, + }); + + const changes = computeChanges(ACTUAL_HTML_MATRIX, matrix, features); + + expect(changes).toHaveLength(3); + const competitors = changes.map((c) => c.competitor); + 
expect(competitors).toContain("VidaiMock"); + expect(competitors).toContain("mock-llm"); + }); +}); + +// ── applyChanges tests with actual HTML structure ────────────────────────── + +describe("applyChanges with actual HTML cell structure", () => { + const ACTUAL_HTML_MATRIX = ` + + + + + + + + + + + + + + + + + + + + + + + + + + +
CapabilityaimockMSWVidaiMockmock-llm
WebSocket APIsBuilt-in ✓
Embeddings APIBuilt-in ✓
`; + + it("replaces span.no cell with span.yes cell for the correct competitor column", () => { + const changes: DetectedChange[] = [ + { competitor: "VidaiMock", capability: "WebSocket APIs", from: "No", to: "Yes" }, + ]; + + const result = applyChanges(ACTUAL_HTML_MATRIX, changes); + + // VidaiMock's WebSocket APIs cell should now be yes + // Parse to verify only VidaiMock column changed + const matrix = parseCurrentMatrix(result); + const wsRow = matrix.rows.get("WebSocket APIs"); + expect(wsRow).toBeDefined(); + // VidaiMock should now have yes checkmark + expect(wsRow!.get("VidaiMock")).toContain("✓"); + expect(wsRow!.get("VidaiMock")).toContain('class="yes"'); + // MSW and mock-llm should still have no + expect(wsRow!.get("MSW")).toContain("✗"); + expect(wsRow!.get("mock-llm")).toContain("✗"); + }); + + it("does not modify cells in other rows", () => { + const changes: DetectedChange[] = [ + { competitor: "VidaiMock", capability: "WebSocket APIs", from: "No", to: "Yes" }, + ]; + + const result = applyChanges(ACTUAL_HTML_MATRIX, changes); + + const matrix = parseCurrentMatrix(result); + const embRow = matrix.rows.get("Embeddings API"); + expect(embRow).toBeDefined(); + // VidaiMock's Embeddings API cell was already yes, should remain + expect(embRow!.get("VidaiMock")).toContain("✓"); + }); + + it("applies multiple changes across different rows and competitors", () => { + const changes: DetectedChange[] = [ + { competitor: "VidaiMock", capability: "WebSocket APIs", from: "No", to: "Yes" }, + { competitor: "mock-llm", capability: "Embeddings API", from: "No", to: "Yes" }, + ]; + + const result = applyChanges(ACTUAL_HTML_MATRIX, changes); + + const matrix = parseCurrentMatrix(result); + expect(matrix.rows.get("WebSocket APIs")!.get("VidaiMock")).toContain('class="yes"'); + expect(matrix.rows.get("Embeddings API")!.get("mock-llm")).toContain('class="yes"'); + }); + + it("returns html unchanged when changes array is empty", () => { + const result = 
applyChanges(ACTUAL_HTML_MATRIX, []); + expect(result).toBe(ACTUAL_HTML_MATRIX); + }); +}); + +// ── extractFeatures tests (tightened keyword patterns) ───────────────────── + +describe("extractFeatures keyword precision", () => { + it("does not trigger Embeddings API on bare word 'embed'", () => { + const text = "You can embed this widget in your page."; + const features = extractFeatures(text); + expect(features["Embeddings API"]).toBe(false); + }); + + it("triggers Embeddings API on /v1/embeddings path", () => { + const text = "Supports the /v1/embeddings endpoint for vector generation."; + const features = extractFeatures(text); + expect(features["Embeddings API"]).toBe(true); + }); + + it("triggers Embeddings API on 'embeddings api' phrase", () => { + const text = "Full support for the embeddings API."; + const features = extractFeatures(text); + expect(features["Embeddings API"]).toBe(true); + }); + + it("does not trigger Image generation on bare word 'image'", () => { + const text = "See the image below for architecture details."; + const features = extractFeatures(text); + expect(features["Image generation"]).toBe(false); + }); + + it("triggers Image generation on 'dall-e' or '/v1/images'", () => { + const text = "Generate images via DALL-E or the /v1/images endpoint."; + const features = extractFeatures(text); + expect(features["Image generation"]).toBe(true); + }); + + it("does not trigger Video generation on bare word 'video'", () => { + const text = "Watch the video tutorial for setup instructions."; + const features = extractFeatures(text); + expect(features["Video generation"]).toBe(false); + }); + + it("triggers Video generation on 'video generation' phrase", () => { + const text = "Supports video generation via the Sora API."; + const features = extractFeatures(text); + expect(features["Video generation"]).toBe(true); + }); + + it("does not trigger Docker image on bare word 'docker'", () => { + const text = "This is like a docker for your tests."; + 
const features = extractFeatures(text); + expect(features["Docker image"]).toBe(false); + }); + + it("triggers Docker image on 'dockerfile' or 'docker image'", () => { + const text = "Includes a Dockerfile for easy deployment."; + const features = extractFeatures(text); + expect(features["Docker image"]).toBe(true); + }); + + it("triggers Docker image on 'docker run'", () => { + const text = "Run with: docker run -p 8080:8080 aimock"; + const features = extractFeatures(text); + expect(features["Docker image"]).toBe(true); + }); +}); From b810eff49105a2c03372f0605115cd2d62388a74 Mon Sep 17 00:00:00 2001 From: Jordan Ritter Date: Thu, 23 Apr 2026 16:23:32 -0700 Subject: [PATCH 5/6] fix: CI workflow security, router RegExp g-flag, test framework patchEnv CI: --auto merge, Slack env vars (5 workflows), portable grep, injection prevention (jq). Router: lastIndex reset. Jest/Vitest: save/restore env vars, loadFixtures console.warn. --- .github/workflows/fix-drift.yml | 13 +++++++++---- .github/workflows/notify-pr.yml | 18 ++++++++++++------ .github/workflows/publish-release.yml | 5 ++++- .github/workflows/test-drift.yml | 4 ++-- .../workflows/update-competitive-matrix.yml | 5 ++++- src/jest.ts | 15 ++++++++++++--- src/router.ts | 12 ++++++++++-- src/vitest.ts | 15 ++++++++++++--- 8 files changed, 65 insertions(+), 22 deletions(-) diff --git a/.github/workflows/fix-drift.yml b/.github/workflows/fix-drift.yml index 85d8aa7..5e0e7b8 100644 --- a/.github/workflows/fix-drift.yml +++ b/.github/workflows/fix-drift.yml @@ -119,7 +119,8 @@ jobs: if: success() && steps.check.outputs.skip != 'true' run: | npx tsx scripts/fix-drift.ts --create-pr 2>&1 | tee /tmp/pr-output.txt - PR_URL=$(grep -oP 'https://github.com/[^ ]+/pull/\d+' /tmp/pr-output.txt | head -1) + PR_URL=$(grep -oE 'https://github.com/[^ ]+/pull/[0-9]+' /tmp/pr-output.txt | head -1) + if [ -z "$PR_URL" ]; then echo "No PR URL found"; exit 1; fi echo "url=$PR_URL" >> $GITHUB_OUTPUT env: GH_TOKEN: ${{ 
secrets.GITHUB_TOKEN }} @@ -129,7 +130,7 @@ jobs: if: success() && steps.pr.outputs.url != '' run: | PR_URL="${{ steps.pr.outputs.url }}" - gh pr merge "$PR_URL" --merge --admin + gh pr merge "$PR_URL" --merge --auto env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -144,15 +145,19 @@ jobs: - name: Notify Slack on fix success if: success() && steps.pr.outputs.url != '' run: | - curl -s -X POST "${{ secrets.SLACK_WEBHOOK }}" \ + if [ -z "$SLACK_WEBHOOK" ]; then echo "SLACK_WEBHOOK not set, skipping"; exit 0; fi + curl -s -X POST "$SLACK_WEBHOOK" \ -H "Content-Type: application/json" \ -d "{\"text\":\"✅ *Drift auto-fixed and merged*\nPR: ${{ steps.pr.outputs.url }}\nRun: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}\"}" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} # Step 8: Slack notification on fix failure - name: Notify Slack on fix failure if: failure() && steps.check.outputs.skip != 'true' run: | - curl -s -X POST "${{ secrets.SLACK_WEBHOOK }}" \ + if [ -z "$SLACK_WEBHOOK" ]; then echo "SLACK_WEBHOOK not set, skipping"; exit 0; fi + curl -s -X POST "$SLACK_WEBHOOK" \ -H "Content-Type: application/json" \ -d "{\"text\":\"❌ *Drift auto-fix failed* — issue created\nRun: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}\"}" env: diff --git a/.github/workflows/notify-pr.yml b/.github/workflows/notify-pr.yml index a143806..2cd80a2 100644 --- a/.github/workflows/notify-pr.yml +++ b/.github/workflows/notify-pr.yml @@ -7,11 +7,17 @@ jobs: runs-on: ubuntu-latest steps: - name: Notify Slack + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} + PR_TITLE: ${{ github.event.pull_request.title }} + PR_URL: ${{ github.event.pull_request.html_url }} + PR_AUTHOR: ${{ github.event.pull_request.user.login }} + PR_NUM: ${{ github.event.pull_request.number }} run: | - PR_TITLE="${{ github.event.pull_request.title }}" - PR_URL="${{ github.event.pull_request.html_url }}" - PR_AUTHOR="${{ github.event.pull_request.user.login 
}}" - PR_NUM="${{ github.event.pull_request.number }}" - curl -sf -X POST "${{ secrets.SLACK_WEBHOOK }}" \ + if [ -z "$SLACK_WEBHOOK" ]; then echo "SLACK_WEBHOOK not set, skipping"; exit 0; fi + PAYLOAD=$(jq -n \ + --arg text "New PR #${PR_NUM} on aimock by *${PR_AUTHOR}*: <${PR_URL}|${PR_TITLE}>" \ + '{text: $text}') + curl -sf -X POST "$SLACK_WEBHOOK" \ -H "Content-Type: application/json" \ - -d "{\"text\": \"🔀 New PR #${PR_NUM} on aimock by *${PR_AUTHOR}*: <${PR_URL}|${PR_TITLE}>\"}" + -d "$PAYLOAD" diff --git a/.github/workflows/publish-release.yml b/.github/workflows/publish-release.yml index c01c719..1012d27 100644 --- a/.github/workflows/publish-release.yml +++ b/.github/workflows/publish-release.yml @@ -96,7 +96,10 @@ jobs: - name: Notify Slack if: steps.check.outputs.published == 'false' run: | + if [ -z "$SLACK_WEBHOOK" ]; then echo "SLACK_WEBHOOK not set, skipping"; exit 0; fi VERSION="v${{ steps.check.outputs.version }}" - curl -s -X POST "${{ secrets.SLACK_WEBHOOK }}" \ + curl -s -X POST "$SLACK_WEBHOOK" \ -H "Content-Type: application/json" \ -d "{\"text\":\"📦 *@copilotkit/aimock ${VERSION} published*\nnpm: https://www.npmjs.com/package/@copilotkit/aimock/v/${{ steps.check.outputs.version }}\nRelease: https://github.com/${{ github.repository }}/releases/tag/${VERSION}\"}" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/.github/workflows/test-drift.yml b/.github/workflows/test-drift.yml index 101190a..52ee746 100644 --- a/.github/workflows/test-drift.yml +++ b/.github/workflows/test-drift.yml @@ -79,7 +79,7 @@ jobs: - name: Notify Slack if: always() run: | - if [ -z "${{ secrets.SLACK_WEBHOOK }}" ]; then exit 0; fi + if [ -z "$SLACK_WEBHOOK" ]; then echo "SLACK_WEBHOOK not set, skipping"; exit 0; fi PREV="${{ steps.prev.outputs.conclusion }}" NOW="${{ job.status }}" @@ -102,7 +102,7 @@ jobs: exit 0 fi - curl -s -X POST "${{ secrets.SLACK_WEBHOOK }}" \ + curl -s -X POST "$SLACK_WEBHOOK" \ -H "Content-Type: application/json" \ -d 
"{\"text\": \"${EMOJI} ${MSG}\"}" env: diff --git a/.github/workflows/update-competitive-matrix.yml b/.github/workflows/update-competitive-matrix.yml index 2d60a08..37323fd 100644 --- a/.github/workflows/update-competitive-matrix.yml +++ b/.github/workflows/update-competitive-matrix.yml @@ -58,6 +58,7 @@ jobs: - name: Notify Slack if: always() run: | + if [ -z "$SLACK_WEBHOOK" ]; then echo "SLACK_WEBHOOK not set, skipping"; exit 0; fi if [ "${{ steps.changes.outputs.changed }}" = "true" ]; then EMOJI="📊" MSG="*Competitive matrix changes detected* — PR created with updated migration pages. " @@ -68,6 +69,8 @@ jobs: EMOJI="❌" MSG="*Competitive matrix update failed*. " fi - curl -s -X POST "${{ secrets.SLACK_WEBHOOK }}" \ + curl -s -X POST "$SLACK_WEBHOOK" \ -H "Content-Type: application/json" \ -d "{\"text\": \"${EMOJI} ${MSG}\"}" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/src/jest.ts b/src/jest.ts index c508feb..5134897 100644 --- a/src/jest.ts +++ b/src/jest.ts @@ -52,6 +52,8 @@ export interface AimockHandle { */ export function useAimock(options: UseAimockOptions = {}): () => AimockHandle { let handle: AimockHandle | null = null; + let origOpenaiUrl: string | undefined; + let origAnthropicUrl: string | undefined; beforeAll(async () => { const { fixtures: fixturePath, patchEnv, ...serverOpts } = options; @@ -68,6 +70,8 @@ export function useAimock(options: UseAimockOptions = {}): () => AimockHandle { const url = await llm.start(); if (patchEnv !== false) { + origOpenaiUrl = process.env.OPENAI_BASE_URL; + origAnthropicUrl = process.env.ANTHROPIC_BASE_URL; process.env.OPENAI_BASE_URL = `${url}/v1`; process.env.ANTHROPIC_BASE_URL = `${url}/v1`; } @@ -84,8 +88,10 @@ export function useAimock(options: UseAimockOptions = {}): () => AimockHandle { afterAll(async () => { if (handle) { if (options.patchEnv !== false) { - delete process.env.OPENAI_BASE_URL; - delete process.env.ANTHROPIC_BASE_URL; + if (origOpenaiUrl !== undefined) 
process.env.OPENAI_BASE_URL = origOpenaiUrl; + else delete process.env.OPENAI_BASE_URL; + if (origAnthropicUrl !== undefined) process.env.ANTHROPIC_BASE_URL = origAnthropicUrl; + else delete process.env.ANTHROPIC_BASE_URL; } await handle.llm.stop(); handle = null; @@ -107,7 +113,10 @@ function loadFixtures(fixturePath: string): Fixture[] { return loadFixturesFromDir(fixturePath); } return loadFixtureFile(fixturePath); - } catch { + } catch (err) { + console.warn( + `[aimock] Failed to load fixtures from ${fixturePath}: ${err instanceof Error ? err.message : String(err)}`, + ); return []; } } diff --git a/src/router.ts b/src/router.ts index f235d50..65f6be8 100644 --- a/src/router.ts +++ b/src/router.ts @@ -67,7 +67,11 @@ export function matchFixture( if (!compatible) continue; } - // userMessage — match against the last user message content + // userMessage — case-sensitive match against the last user message content. + // String matching is intentionally case-sensitive so fixture authors can + // rely on exact string values. This differs from the case-insensitive + // matchesPattern() in helpers.ts, which is used for search/rerank/moderation + // where exact casing rarely matters. if (match.userMessage !== undefined) { const msg = getLastMessageByRole(effective.messages, "user"); const text = msg ? getTextContent(msg.content) : null; @@ -79,6 +83,7 @@ export function matchFixture( if (!text.includes(match.userMessage)) continue; } } else { + match.userMessage.lastIndex = 0; if (!match.userMessage.test(text)) continue; } } @@ -96,7 +101,8 @@ export function matchFixture( if (!found) continue; } - // inputText — match against the embedding input text (used by embeddings endpoint) + // inputText — case-sensitive match against the embedding input text. + // Same rationale as userMessage above: fixture authors specify exact strings. 
if (match.inputText !== undefined) { const embeddingInput = effective.embeddingInput; if (!embeddingInput) continue; @@ -107,6 +113,7 @@ export function matchFixture( if (!embeddingInput.includes(match.inputText)) continue; } } else { + match.inputText.lastIndex = 0; if (!match.inputText.test(embeddingInput)) continue; } } @@ -122,6 +129,7 @@ export function matchFixture( if (typeof match.model === "string") { if (effective.model !== match.model) continue; } else { + match.model.lastIndex = 0; if (!match.model.test(effective.model)) continue; } } diff --git a/src/vitest.ts b/src/vitest.ts index 72a1b7c..0d55563 100644 --- a/src/vitest.ts +++ b/src/vitest.ts @@ -43,6 +43,8 @@ export interface AimockHandle { */ export function useAimock(options: UseAimockOptions = {}): () => AimockHandle { let handle: AimockHandle | null = null; + let origOpenaiUrl: string | undefined; + let origAnthropicUrl: string | undefined; beforeAll(async () => { const { fixtures: fixturePath, patchEnv, ...serverOpts } = options; @@ -59,6 +61,8 @@ export function useAimock(options: UseAimockOptions = {}): () => AimockHandle { const url = await llm.start(); if (patchEnv !== false) { + origOpenaiUrl = process.env.OPENAI_BASE_URL; + origAnthropicUrl = process.env.ANTHROPIC_BASE_URL; process.env.OPENAI_BASE_URL = `${url}/v1`; process.env.ANTHROPIC_BASE_URL = `${url}/v1`; } @@ -75,8 +79,10 @@ export function useAimock(options: UseAimockOptions = {}): () => AimockHandle { afterAll(async () => { if (handle) { if (options.patchEnv !== false) { - delete process.env.OPENAI_BASE_URL; - delete process.env.ANTHROPIC_BASE_URL; + if (origOpenaiUrl !== undefined) process.env.OPENAI_BASE_URL = origOpenaiUrl; + else delete process.env.OPENAI_BASE_URL; + if (origAnthropicUrl !== undefined) process.env.ANTHROPIC_BASE_URL = origAnthropicUrl; + else delete process.env.ANTHROPIC_BASE_URL; } await handle.llm.stop(); handle = null; @@ -98,7 +104,10 @@ function loadFixtures(fixturePath: string): Fixture[] { return 
loadFixturesFromDir(fixturePath); } return loadFixtureFile(fixturePath); - } catch { + } catch (err) { + console.warn( + `[aimock] Failed to load fixtures from ${fixturePath}: ${err instanceof Error ? err.message : String(err)}`, + ); return []; } } From 903acf9cdd506a58da026bc2c7f94fd02103e83a Mon Sep 17 00:00:00 2001 From: Jordan Ritter Date: Thu, 23 Apr 2026 16:23:37 -0700 Subject: [PATCH 6/6] chore: release v1.15.1 --- CHANGELOG.md | 42 ++++++++++++++++++++++++++---------------- package.json | 2 +- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 07d1f1a..e5f051d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,23 +1,33 @@ # @copilotkit/aimock -## 1.15.0 +## 1.15.1 -### Added +### Fixed -- Chaos injection in proxy mode: drop and disconnect fire pre-flight (upstream is never - contacted), malformed proxies the request then corrupts the response body before - delivering it to the client. -- SSE streaming bypass: when upstream responds with `text/event-stream`, malformed chaos - is silently skipped (bytes are already on the wire). A bypass metric - (`aimock_chaos_bypassed_total`) is emitted so operators can see the configured action - did not fire; the normal `aimock_chaos_triggered_total` counter does not increment. -- Chaos source label (`fixture` vs `proxy`) on Prometheus metrics and journal entries, - distinguishing where the chaos decision was made. -- CORS `Access-Control-Allow-Headers` now includes `X-Aimock-Chaos-Drop`, - `X-Aimock-Chaos-Malformed`, `X-Aimock-Chaos-Disconnect`, and `X-Test-Id`, enabling - browser-based clients to send per-request chaos overrides via preflight-safe headers. -- `handleVideoStatus` (`GET /v1/videos/:id`) now evaluates chaos before returning video - state, consistent with all other handler endpoints. 
+- **Recorder**: crash hardening (headersSent guards, clientDisconnected tracking), + preserve content alongside toolCalls, Cohere v2 native detection, tool-call ID + extraction from 5 providers, reasoning/thinking extraction from 4 providers, + multi-block text join (filter+join instead of find), thinking-only and empty-content + response handling, Ollama /api/generate format detection, streaming collapse + reasoning propagation. +- **Bedrock/Converse**: ContentWithToolCallsResponse support, ResponseOverrides wired + into all non-streaming and streaming builders, Converse-wrapped stream event format, + text_delta type field on text deltas, proper error envelope on Converse errors, + webSearches warnings. +- **Cohere v2**: reasoning in all builders + streaming, webSearches warnings, + response_format forwarding, assistant tool_calls preservation, full + ResponseOverrides (finish_reason, usage, id) in non-streaming and streaming paths. +- **Server**: readBody 10MB size limit, control API error detail, one-shot error fixture + race fix, normalizeCompatPath clarity, fixtures_loaded gauge updates on mutations. +- **Competitive matrix**: HTML pipeline fixed (computeChanges, applyChanges, + updateProviderCounts, extractFeatures all aligned with actual DOM structure). +- **CI workflows**: --auto merge (respects branch protection), Slack secrets via env + vars, script injection prevention in notify-pr.yml, portable grep. +- **Router**: RegExp g-flag lastIndex reset prevents alternating match/no-match. +- **Jest/Vitest**: save/restore pre-existing env vars in afterAll, loadFixtures + console.warn on failure. +- **Gemini**: tool_call_id collision fix (shared callCounter), thought-part filtering. +- **Ollama**: ContentWithToolCallsResponse support, default stream:true, field validation. 
## 1.14.9 diff --git a/package.json b/package.json index f59ddf8..06fecc0 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@copilotkit/aimock", - "version": "1.15.0", + "version": "1.15.1", "description": "Mock infrastructure for AI application testing — LLM APIs, image generation, text-to-speech, transcription, video generation, MCP tools, A2A agents, AG-UI event streams, vector databases, search, rerank, and moderation. One package, one port, zero dependencies.", "license": "MIT", "keywords": [