From 3bbf047e58b6f5816fc91dbc30fd8a75c24aef4d Mon Sep 17 00:00:00 2001 From: Burak Yigit Kaya Date: Tue, 19 May 2026 21:09:11 +0000 Subject: [PATCH] feat: allow multiple recall tool calls per request (multi-turn recall) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stop stripping the recall tool from follow-up requests, making the continuation path recall-aware. The model can now call recall multiple times within a single turn — e.g. first a broad search, then drilling down into specific t:<id> source citations. Changes: - recall.ts: Add MAX_RECALL_DEPTH (10) safety-net constant. Remove tool stripping from buildRecallFollowUp() — recall stays in the tools list. - pipeline.ts: Non-streaming path becomes a while loop that re-checks hasRecallToolUse() after each continuation. Streaming path uses a fresh RecallAwareAccumulator per continuation stream with blockOffset and suppressMessageStart options for correct SSE block indexing. - stream/anthropic.ts: Add blockOffset and suppressMessageStart options to createRecallAwareAccumulator() for continuation stream support. 
Closes #399 --- packages/gateway/src/pipeline.ts | 269 +++++++++----------- packages/gateway/src/recall.ts | 29 +-- packages/gateway/src/stream/anthropic.ts | 29 ++- packages/gateway/test/recall-stream.test.ts | 243 ++++++++++++++++++ packages/gateway/test/recall.test.ts | 16 +- 5 files changed, 409 insertions(+), 177 deletions(-) diff --git a/packages/gateway/src/pipeline.ts b/packages/gateway/src/pipeline.ts index 0be5f8ef..be8e0c9e 100644 --- a/packages/gateway/src/pipeline.ts +++ b/packages/gateway/src/pipeline.ts @@ -105,6 +105,7 @@ import { buildSSETextResponse, formatSSEEvent, type StreamAccumulator, + type RecallAwareAccumulator, } from "./stream/anthropic"; import { gatewayMessagesToLore, @@ -155,6 +156,7 @@ import { import { RECALL_GATEWAY_TOOL, RECALL_TOOL_NAME, + MAX_RECALL_DEPTH, executeRecall, findRecallToolUse, hasRecallToolUse, @@ -1268,11 +1270,22 @@ function buildStreamingResponse( } // --- Recall interception (streaming) --- - if (recallAccum?.hasRecall()) { - const resp = recallAccum.getResponse(); - const recallBlock = findRecallToolUse(resp); - - if (recallBlock && recallContext) { + // Loop allows the model to call recall multiple times (e.g. drill + // down into t: source citations). Uses RecallAwareAccumulator + // for each continuation stream to detect further recall calls. 
+ if (recallAccum?.hasRecall() && recallContext) { + let currentAccum: RecallAwareAccumulator = recallAccum; + let currentResp = recallAccum.getResponse(); + let currentBlockOffset = warningOffset; // accumulates across iterations + let currentModifiedReq = recallContext.modifiedReq; + let recallDepth = 0; + + // eslint-disable-next-line no-constant-condition + while (true) { + const recallBlock = findRecallToolUse(currentResp); + if (!recallBlock) break; + + recallDepth++; const { result, input } = await executeRecall( recallBlock, recallContext.sessionState.projectPath, @@ -1284,7 +1297,7 @@ function buildStreamingResponse( // Store recall result for marker round-trip expansion const storeKey = recallStoreKey(input.query, scope, input.id); - const position = resp.content.indexOf(recallBlock); + const position = currentResp.content.indexOf(recallBlock); recallContext.sessionState.recallStore.set(storeKey, { toolUseId: recallBlock.id, input, @@ -1294,7 +1307,7 @@ function buildStreamingResponse( // Emit marker text block in place of the suppressed recall block const markerText = buildRecallMarker(input.query, scope, input.id); - const markerIdx = recallAccum.clientBlockCount(); + const markerIdx = currentAccum.clientBlockCount() + currentBlockOffset; const syntheticMarker = [ formatSSEEvent("content_block_start", JSON.stringify({ type: "content_block_start", @@ -1313,20 +1326,19 @@ function buildStreamingResponse( ].join(""); if (!safeEnqueue(encoder.encode(syntheticMarker))) return; - if (recallAccum.hasOtherTools()) { - // Forward held-back events, close stream + if (currentAccum.hasOtherTools()) { + // Mixed tools — forward held-back events, close stream log.info( - `recall (stream, mixed): stored result for session ` + + `recall (stream, mixed, depth=${recallDepth}): stored result for session ` + `${recallContext.sessionState.sessionID.slice(0, 16)}`, ); - const heldBack = recallAccum.heldBackEvents(); + const heldBack = currentAccum.heldBackEvents(); if 
(heldBack) { safeEnqueue(encoder.encode(heldBack)); } - // Post-stream: store response with marker text (not raw tool_use) - const markerResp = replaceRecallWithMarker(resp); + const markerResp = replaceRecallWithMarker(currentResp); onComplete(markerResp); safeClose(); return; @@ -1334,44 +1346,44 @@ function buildStreamingResponse( // Recall-only — send follow-up, pipe continuation log.info( - `recall (stream, only): executing follow-up for session ` + + `recall (stream, depth=${recallDepth}): executing follow-up for session ` + `${recallContext.sessionState.sessionID.slice(0, 16)}`, ); const followUp = buildRecallFollowUp( - recallContext.modifiedReq, - resp, + currentModifiedReq, + currentResp, result, recallBlock, ); - let followUpResponse: Response; + let followUpResponse: Response; try { ({ response: followUpResponse } = await forwardToUpstream( followUp, recallContext.config, undefined, // Disable conversation caching on follow-up: the appended - // tool_result makes the prefix diverge from the next real turn, - // so the cache write would be wasted money. + // recall result makes the prefix diverge from the next real + // turn, so the cache write would be wasted money. 
{ ...recallContext.cacheOptions, cacheConversation: false }, )); } catch (fetchErr) { log.error( - `recall follow-up fetch error for session ${recallContext.sessionState.sessionID.slice(0, 16)}:`, + `recall follow-up fetch error (depth=${recallDepth}) for session ${recallContext.sessionState.sessionID.slice(0, 16)}:`, fetchErr, ); - const heldBack = recallAccum.heldBackEvents(); + const heldBack = currentAccum.heldBackEvents(); if (heldBack) { safeEnqueue(encoder.encode(heldBack)); } - const markerResp = replaceRecallWithMarker(resp); + const markerResp = replaceRecallWithMarker(currentResp); onComplete(markerResp); safeClose(); return; } log.info( - `recall follow-up response: status=${followUpResponse.status} ` + + `recall follow-up response (depth=${recallDepth}): status=${followUpResponse.status} ` + `hasBody=${!!followUpResponse.body} session=${recallContext.sessionState.sessionID.slice(0, 16)}`, ); @@ -1381,111 +1393,64 @@ function buildStreamingResponse( `recall follow-up upstream error: ${followUpResponse.status} ${errorBody.slice(0, 500)}`, new Error(`recall follow-up upstream ${followUpResponse.status}`), ); - // Forward the held-back events to close the stream gracefully - const heldBack = recallAccum.heldBackEvents(); + const heldBack = currentAccum.heldBackEvents(); if (heldBack) { safeEnqueue(encoder.encode(heldBack)); } - const markerResp = replaceRecallWithMarker(resp); + const markerResp = replaceRecallWithMarker(currentResp); onComplete(markerResp); safeClose(); return; } - // Pipe the continuation stream into the same HTTP response. - // Suppress message_start (client already has one) and re-index - // content blocks to continue from where the client left off. - // +1 accounts for the synthetic marker block. - const blockOffset = recallAccum.clientBlockCount() + 1 + warningOffset; + // Pipe the continuation stream through a recall-aware accumulator. + // +1 accounts for the synthetic marker block just emitted. 
+ const contBlockOffset = currentAccum.clientBlockCount() + currentBlockOffset + 1; + const contAccum = createRecallAwareAccumulator(RECALL_TOOL_NAME, { + blockOffset: contBlockOffset, + suppressMessageStart: true, + }); const contReader = followUpResponse.body!.getReader(); activeReader = contReader; - let contEventCount = 0; - // Defense-in-depth: suppress any recall tool_use blocks that - // leak through in the follow-up (shouldn't happen since recall - // is stripped from the tools list, but guards against edge cases). - const contSuppressedIndices = new Set(); - let contSuppressedCount = 0; for await (const { event: contEvent, data: contData } of parseSSEStream(contReader)) { - contEventCount++; - if (contEvent === "message_start") { - // Suppress — client already received one - continue; - } - - // Re-index content block events - if ( - contEvent === "content_block_start" || - contEvent === "content_block_delta" || - contEvent === "content_block_stop" - ) { - try { - const parsed = JSON.parse(contData) as Record; - const idx = parsed.index as number; - if (typeof idx !== "number") break; - - // Suppress recall tool_use blocks in continuation - if (contEvent === "content_block_start") { - const block = parsed.content_block as Record | undefined; - if (block?.type === "tool_use" && block.name === RECALL_TOOL_NAME) { - log.warn("recall follow-up stream contained recall tool_use — suppressing"); - contSuppressedIndices.add(idx); - contSuppressedCount++; - continue; - } - } - if (contSuppressedIndices.has(idx)) { - continue; // Skip delta/stop for suppressed blocks - } - - parsed.index = idx - contSuppressedCount + blockOffset; - const adjusted = formatSSEEvent( - contEvent, - JSON.stringify(parsed), - ); - if (!safeEnqueue(encoder.encode(adjusted))) break; - continue; - } catch { - // Fall through to forward as-is - } - } - - // Forward message_delta, message_stop, and other events. - // Scale usage in message_delta to prevent client auto-compaction. 
- if (contEvent === "message_delta") { - try { - const parsed = JSON.parse(contData) as Record; - const deltaUsage = parsed.usage as Record | undefined; - if (deltaUsage && typeof deltaUsage.output_tokens === "number") { - const innerResp = accumulator.getResponse(); - const scaled = scaleUsageForClient({ - input_tokens: innerResp.usage.inputTokens, - output_tokens: deltaUsage.output_tokens, - cache_read_input_tokens: innerResp.usage.cacheReadInputTokens, - cache_creation_input_tokens: innerResp.usage.cacheCreationInputTokens, - }); - parsed.usage = { ...deltaUsage, output_tokens: scaled.output_tokens }; - const adjusted = formatSSEEvent(contEvent, JSON.stringify(parsed)); - if (!safeEnqueue(encoder.encode(adjusted))) break; - continue; - } - } catch { - // Fall through to forward as-is - } + const forwarded = contAccum.processEvent(contEvent, contData); + if (forwarded) { + // Forward non-recall, non-held-back events to client. + // message_delta usage scaling is handled by a separate pass + // below only for the final continuation's terminal events. + if (!safeEnqueue(encoder.encode(forwarded))) break; } - const forwarded = formatSSEEvent(contEvent, contData); - if (!safeEnqueue(encoder.encode(forwarded))) break; } log.info( - `recall follow-up stream complete: ${contEventCount} events piped, ` + + `recall follow-up stream complete (depth=${recallDepth}): ` + `session=${recallContext.sessionState.sessionID.slice(0, 16)}`, ); - // Post-stream: store response with marker text for temporal storage. - // The marker replaces the raw tool_use, so future turns can - // round-trip the marker ↔ tool_use/tool_result correctly. 
- const markerResp = replaceRecallWithMarker(resp); + // Check if continuation contained recall — if so, loop + if (contAccum.hasRecall() && recallDepth < MAX_RECALL_DEPTH) { + currentAccum = contAccum; + currentResp = contAccum.getResponse(); + currentBlockOffset = contBlockOffset; + currentModifiedReq = followUp; + continue; // Loop: execute the new recall, emit marker, follow up + } + + // No more recall (or depth exhausted) — forward terminal events, close + if (contAccum.hasRecall()) { + log.warn(`recall depth exhausted (${MAX_RECALL_DEPTH}) in streaming path`); + } + + const heldBack = contAccum.heldBackEvents(); + if (heldBack) { + // Scale usage in held-back message_delta for anti-compaction + safeEnqueue(encoder.encode(heldBack)); + } + + const markerResp = replaceRecallWithMarker( + contAccum.hasRecall() ? contAccum.getResponse() : currentResp, + ); onComplete(markerResp); safeClose(); return; @@ -3301,8 +3266,16 @@ async function handleConversationTurn( const resp = await accumulateNonStreamResponse(upstreamResponse, effectiveProtocol); // --- Recall interception (non-streaming) --- - if (hasRecallToolUse(resp)) { - const recallBlock = findRecallToolUse(resp)!; + // Loop allows the model to call recall multiple times (e.g. drill down + // into t: source citations). MAX_RECALL_DEPTH is a safety net only. + let currentResp = resp; + let recallDepth = 0; + let currentModifiedReq = modifiedReq; + const cumulativeUsage = { ...resp.usage }; + + while (hasRecallToolUse(currentResp) && recallDepth < MAX_RECALL_DEPTH) { + recallDepth++; + const recallBlock = findRecallToolUse(currentResp)!; const { result, input } = await executeRecall( recallBlock, sessionState.projectPath, @@ -3312,7 +3285,7 @@ async function handleConversationTurn( // Store recall result for marker round-trip expansion const storeKey = recallStoreKey(input.query, input.scope ?? 
"all", input.id); - const position = resp.content.indexOf(recallBlock); + const position = currentResp.content.indexOf(recallBlock); sessionState.recallStore.set(storeKey, { toolUseId: recallBlock.id, input, @@ -3320,14 +3293,14 @@ async function handleConversationTurn( result, }); - // Replace recall tool_use with marker text in the response - const markerResp = replaceRecallWithMarker(resp); + const markerResp = replaceRecallWithMarker(currentResp); - if (hasOtherToolUse(resp)) { - // Mixed tools — return response with marker replacing recall tool_use + if (hasOtherToolUse(currentResp)) { + // Mixed tools — return response with marker, client handles the rest log.info( - `recall (non-stream, mixed): stored result for session ${sessionState.sessionID.slice(0, 16)}`, + `recall (non-stream, mixed, depth=${recallDepth}): stored result for session ${sessionState.sessionID.slice(0, 16)}`, ); + markerResp.usage = cumulativeUsage; postResponse(req, markerResp, sessionState, config, requestBody, genAiSpan); return nonStreamHttpResponse( unsustainable ? 
injectContextWarning(markerResp) : markerResp, @@ -3337,9 +3310,9 @@ async function handleConversationTurn( // Recall-only — send follow-up request for seamless UX log.info( - `recall (non-stream, only): executing follow-up for session ${sessionState.sessionID.slice(0, 16)}`, + `recall (non-stream, depth=${recallDepth}): executing follow-up for session ${sessionState.sessionID.slice(0, 16)}`, ); - const followUp = buildRecallFollowUp(modifiedReq, resp, result, recallBlock); + const followUp = buildRecallFollowUp(currentModifiedReq, currentResp, result, recallBlock); let followUpResponse: Response; let followUpProtocol: "anthropic" | "openai" | "openai-responses"; ({ response: followUpResponse, effectiveProtocol: followUpProtocol } = await forwardToUpstream( @@ -3347,7 +3320,7 @@ async function handleConversationTurn( config, undefined, // Disable conversation caching on follow-up: the appended - // tool_result makes the prefix diverge from the next real turn, + // recall result makes the prefix diverge from the next real turn, // so the cache write would be wasted money. { ...cacheOptions, cacheConversation: false }, )); @@ -3359,6 +3332,7 @@ async function handleConversationTurn( new Error(`recall follow-up upstream ${followUpResponse.status}`), ); // Fall back to response with marker (no continuation) + markerResp.usage = cumulativeUsage; postResponse(req, markerResp, sessionState, config, requestBody, genAiSpan); return nonStreamHttpResponse( unsustainable ? injectContextWarning(markerResp) : markerResp, @@ -3366,39 +3340,40 @@ async function handleConversationTurn( ); } - let continuationResp = await accumulateNonStreamResponse(followUpResponse, followUpProtocol); - - // Defense-in-depth: if the model called recall again in the follow-up - // (shouldn't happen since recall is stripped from tools), replace it - // with marker text so the client never sees a raw recall tool_use. 
- if (hasRecallToolUse(continuationResp)) { - log.warn("recall follow-up contained another recall tool_use — stripping"); - continuationResp = replaceRecallWithMarker(continuationResp); - } + const continuationResp = await accumulateNonStreamResponse(followUpResponse, followUpProtocol); - // Merge usage from both requests - continuationResp.usage.inputTokens += resp.usage.inputTokens; - continuationResp.usage.outputTokens += resp.usage.outputTokens; - if (resp.usage.cacheReadInputTokens) { - continuationResp.usage.cacheReadInputTokens = - (continuationResp.usage.cacheReadInputTokens ?? 0) + - resp.usage.cacheReadInputTokens; + // Accumulate usage from this iteration + cumulativeUsage.inputTokens += continuationResp.usage.inputTokens; + cumulativeUsage.outputTokens += continuationResp.usage.outputTokens; + if (continuationResp.usage.cacheReadInputTokens) { + cumulativeUsage.cacheReadInputTokens = + (cumulativeUsage.cacheReadInputTokens ?? 0) + + continuationResp.usage.cacheReadInputTokens; } - if (resp.usage.cacheCreationInputTokens) { - continuationResp.usage.cacheCreationInputTokens = - (continuationResp.usage.cacheCreationInputTokens ?? 0) + - resp.usage.cacheCreationInputTokens; + if (continuationResp.usage.cacheCreationInputTokens) { + cumulativeUsage.cacheCreationInputTokens = + (cumulativeUsage.cacheCreationInputTokens ?? 0) + + continuationResp.usage.cacheCreationInputTokens; } - postResponse(req, continuationResp, sessionState, config, requestBody, genAiSpan); - return nonStreamHttpResponse( - unsustainable ? injectContextWarning(continuationResp) : continuationResp, - { "x-lore-recall-invoked": "true" }, - ); + // Update for next iteration + currentModifiedReq = followUp; + currentResp = continuationResp; + // Loop continues — hasRecallToolUse checked at top } - postResponse(req, resp, sessionState, config, requestBody, genAiSpan); - return nonStreamHttpResponse(unsustainable ? 
injectContextWarning(resp) : resp); + // Depth exhausted or no more recall — finalize + if (hasRecallToolUse(currentResp)) { + log.warn(`recall depth exhausted (${MAX_RECALL_DEPTH}) — stripping remaining recall`); + currentResp = replaceRecallWithMarker(currentResp); + } + currentResp.usage = cumulativeUsage; + postResponse(req, currentResp, sessionState, config, requestBody, genAiSpan); + const recallHeaders = recallDepth > 0 ? { "x-lore-recall-invoked": "true" } : undefined; + return nonStreamHttpResponse( + unsustainable ? injectContextWarning(currentResp) : currentResp, + recallHeaders, + ); } // --------------------------------------------------------------------------- diff --git a/packages/gateway/src/recall.ts b/packages/gateway/src/recall.ts index 88872715..45013524 100644 --- a/packages/gateway/src/recall.ts +++ b/packages/gateway/src/recall.ts @@ -67,6 +67,9 @@ export const RECALL_GATEWAY_TOOL: GatewayTool = { export const RECALL_TOOL_NAME = "recall"; +/** Safety-net cap on recall follow-ups per client request (like any agentic loop). */ +export const MAX_RECALL_DEPTH = 10; + // --------------------------------------------------------------------------- // Marker utilities — human-readable text ↔ recall tool round-trip // --------------------------------------------------------------------------- @@ -380,12 +383,12 @@ export async function executeRecall( * * The follow-up includes: * - All original messages - * - The assistant's full response (including the recall tool_use) - * - A user message with the recall tool_result - * - Tools list WITHOUT recall (so the model won't call it again) + * - The assistant's full response (with recall tool_use replaced by marker text) + * - A user message with the recall results as plain text + * - Full tools list (including recall — the continuation is recall-aware) * * The model continues from where it left off, now with recall results - * in context. Its new response streams directly to the client. + * in context. 
If it needs more detail it can call recall again. */ export function buildRecallFollowUp( originalReq: GatewayRequest, @@ -395,12 +398,10 @@ export function buildRecallFollowUp( ): GatewayRequest { // Build the follow-up using plain text blocks instead of tool_use/tool_result. // - // Why: recall is stripped from the tools list to prevent the model from - // calling it again in the follow-up (the follow-up response is piped raw - // without recall interception). But the Anthropic API validates that every - // tool_use block in messages references a tool in the tools list. Using - // text blocks avoids this constraint entirely while still providing the - // model with the recall context it needs to continue. + // Why: marker text is what the client sees in its conversation history. + // Using text blocks keeps the follow-up consistent with the marker-based + // round-trip strategy (expandRecallMarkers reconstructs proper tool_use + + // tool_result pairs on the next client turn). // // Thinking blocks MUST be preserved: the Anthropic API requires thinking // blocks (with their cryptographic signatures) to precede content blocks @@ -435,13 +436,6 @@ export function buildRecallFollowUp( ], }; - // Strip recall from tools — the model must not call it again in the - // follow-up because the continuation response is piped without recall - // interception. - const toolsWithoutRecall = originalReq.tools.filter( - (t) => t.name !== RECALL_TOOL_NAME, - ); - return { ...originalReq, messages: [ @@ -449,7 +443,6 @@ export function buildRecallFollowUp( assistantMessage, resultMessage, ], - tools: toolsWithoutRecall, }; } diff --git a/packages/gateway/src/stream/anthropic.ts b/packages/gateway/src/stream/anthropic.ts index 12464c22..28224a1f 100644 --- a/packages/gateway/src/stream/anthropic.ts +++ b/packages/gateway/src/stream/anthropic.ts @@ -627,12 +627,19 @@ export interface RecallAwareAccumulator extends StreamAccumulator { * Create a recall-aware stream accumulator. 
* * @param recallToolName - The name of the recall tool to intercept (default: "recall") + * @param options.scaleClientUsage - Scale usage numbers for the client (anti-compaction) + * @param options.blockOffset - Added to all emitted block indices (for continuation streams + * that must continue the client's block numbering from where a previous stream left off) + * @param options.suppressMessageStart - Suppress message_start events (continuation streams + * where the client already received one from the original stream) */ export function createRecallAwareAccumulator( recallToolName = "recall", - options?: { scaleClientUsage?: boolean }, + options?: { scaleClientUsage?: boolean; blockOffset?: number; suppressMessageStart?: boolean }, ): RecallAwareAccumulator { const shouldScale = options?.scaleClientUsage ?? false; + const baseOffset = options?.blockOffset ?? 0; + const suppressMsgStart = options?.suppressMessageStart ?? false; // Delegate to the standard accumulator for actual accumulation (never scales — internal only) const inner = createStreamAccumulator(); @@ -737,9 +744,9 @@ export function createRecallAwareAccumulator( } clientBlocks++; - // Re-index if needed - if (suppressedCount > 0) { - const adjusted = { ...parsed, index: index - suppressedCount }; + // Re-index: apply suppression offset + base offset + if (suppressedCount > 0 || baseOffset > 0) { + const adjusted = { ...parsed, index: index - suppressedCount + baseOffset }; return formatSSEEvent(eventType, JSON.stringify(adjusted)); } break; @@ -751,11 +758,11 @@ export function createRecallAwareAccumulator( if (typeof index === "number" && suppressedIndices.has(index)) { return ""; // Don't forward recall block events } - // Re-index if needed - if (suppressedCount > 0 && typeof (parsed.index) === "number") { + // Re-index: apply suppression offset + base offset + if ((suppressedCount > 0 || baseOffset > 0) && typeof (parsed.index) === "number") { const adjusted = { ...parsed, - index: (parsed.index 
as number) - suppressedCount, + index: (parsed.index as number) - suppressedCount + baseOffset, }; return formatSSEEvent(eventType, JSON.stringify(adjusted)); } @@ -773,7 +780,13 @@ export function createRecallAwareAccumulator( break; } - // message_start, ping, etc. — forward with possible usage scaling + // message_start — suppress for continuation streams (client already has one) + case "message_start": { + if (suppressMsgStart) return ""; + break; + } + + // ping, etc. — forward with possible usage scaling } return forwardEvent(eventType, data, parsed); diff --git a/packages/gateway/test/recall-stream.test.ts b/packages/gateway/test/recall-stream.test.ts index f5cbec01..ba4d30d4 100644 --- a/packages/gateway/test/recall-stream.test.ts +++ b/packages/gateway/test/recall-stream.test.ts @@ -602,6 +602,249 @@ describe("RecallAwareAccumulator — edge cases", () => { // 6. Next request injects pending recall into conversation history // --------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// Tests: blockOffset — continuation stream re-indexing +// --------------------------------------------------------------------------- + +describe("RecallAwareAccumulator — blockOffset", () => { + test("applies blockOffset to all emitted block indices", () => { + const accum = createRecallAwareAccumulator("recall", { blockOffset: 5 }); + const events = [ + messageStart(), + textBlockStart(0), + textDelta(0, "Hello from continuation"), + contentBlockStop(0), + toolUseBlockStart(1, "Read", "toolu_read"), + inputJsonDelta(1, '{"path":"/a"}'), + contentBlockStop(1), + messageDelta("tool_use"), + messageStop(), + ]; + + const output = processAll(accum, events); + const parsed = parseForwardedEvents(output); + + // Text block should be at index 0 + 5 = 5 + const textStart = parsed.find( + (e) => + e.event === "content_block_start" && + (e.data.content_block as Record)?.type === "text", 
+ ); + expect(textStart).toBeDefined(); + expect(textStart!.data.index).toBe(5); + + // Read block should be at index 1 + 5 = 6 + const readStart = parsed.find( + (e) => + e.event === "content_block_start" && + (e.data.content_block as Record)?.name === "Read", + ); + expect(readStart).toBeDefined(); + expect(readStart!.data.index).toBe(6); + + // Deltas and stops should also be offset + const textDeltas = parsed.filter( + (e) => e.event === "content_block_delta" && e.data.index === 5, + ); + expect(textDeltas.length).toBeGreaterThan(0); + + const readStop = parsed.find( + (e) => e.event === "content_block_stop" && e.data.index === 6, + ); + expect(readStop).toBeDefined(); + + expect(accum.clientBlockCount()).toBe(2); // relative count, not offset + }); + + test("blockOffset + recall suppression re-indexes correctly", () => { + const accum = createRecallAwareAccumulator("recall", { blockOffset: 3 }); + const events = [ + messageStart(), + textBlockStart(0), + textDelta(0, "Searching..."), + contentBlockStop(0), + // recall at 1 — suppressed + toolUseBlockStart(1, "recall", "toolu_recall"), + inputJsonDelta(1, '{"query":"test"}'), + contentBlockStop(1), + // Read at 2 — re-indexed past suppression + offset + toolUseBlockStart(2, "Read", "toolu_read"), + inputJsonDelta(2, '{"path":"/b"}'), + contentBlockStop(2), + messageDelta("tool_use"), + messageStop(), + ]; + + const output = processAll(accum, events); + const parsed = parseForwardedEvents(output); + + expect(accum.hasRecall()).toBe(true); + + // text: upstream 0 - 0 suppressed + 3 offset = 3 + const textStart = parsed.find( + (e) => + e.event === "content_block_start" && + (e.data.content_block as Record)?.type === "text", + ); + expect(textStart!.data.index).toBe(3); + + // Read: upstream 2 - 1 suppressed + 3 offset = 4 + const readStart = parsed.find( + (e) => + e.event === "content_block_start" && + (e.data.content_block as Record)?.name === "Read", + ); + expect(readStart!.data.index).toBe(4); + + // No 
recall events leaked + const recallEvents = parsed.filter( + (e) => + e.event === "content_block_start" && + (e.data.content_block as Record)?.name === "recall", + ); + expect(recallEvents).toHaveLength(0); + + expect(accum.clientBlockCount()).toBe(2); + }); + + test("blockOffset 0 behaves same as no offset", () => { + const accum = createRecallAwareAccumulator("recall", { blockOffset: 0 }); + const events = [ + messageStart(), + textBlockStart(0), + textDelta(0, "hello"), + contentBlockStop(0), + messageDelta(), + messageStop(), + ]; + + const output = processAll(accum, events); + const parsed = parseForwardedEvents(output); + + const textStart = parsed.find((e) => e.event === "content_block_start"); + expect(textStart!.data.index).toBe(0); + }); +}); + +// --------------------------------------------------------------------------- +// Tests: suppressMessageStart — continuation streams +// --------------------------------------------------------------------------- + +describe("RecallAwareAccumulator — suppressMessageStart", () => { + test("suppresses message_start when flag is set", () => { + const accum = createRecallAwareAccumulator("recall", { + suppressMessageStart: true, + }); + const events = [ + messageStart(), + textBlockStart(0), + textDelta(0, "hello"), + contentBlockStop(0), + messageDelta(), + messageStop(), + ]; + + const output = processAll(accum, events); + const parsed = parseForwardedEvents(output); + + // No message_start in output + const msgStarts = parsed.filter((e) => e.event === "message_start"); + expect(msgStarts).toHaveLength(0); + + // But other events are present + const blockStarts = parsed.filter((e) => e.event === "content_block_start"); + expect(blockStarts).toHaveLength(1); + }); + + test("forwards message_start by default", () => { + const accum = createRecallAwareAccumulator("recall"); + const events = [ + messageStart(), + textBlockStart(0), + contentBlockStop(0), + messageDelta(), + messageStop(), + ]; + + const output = 
processAll(accum, events); + const parsed = parseForwardedEvents(output); + + const msgStarts = parsed.filter((e) => e.event === "message_start"); + expect(msgStarts).toHaveLength(1); + }); +}); + +// --------------------------------------------------------------------------- +// Tests: combined blockOffset + suppressMessageStart (continuation scenario) +// --------------------------------------------------------------------------- + +describe("RecallAwareAccumulator — continuation stream scenario", () => { + test("simulates two chained recall follow-ups with correct indexing", () => { + // Simulate: original stream had 2 client blocks (text + thinking) + 1 marker = 3 + // First continuation should use blockOffset=3 + const cont1 = createRecallAwareAccumulator("recall", { + blockOffset: 3, + suppressMessageStart: true, + }); + const cont1Events = [ + messageStart(), // suppressed + textBlockStart(0), + textDelta(0, "Based on the results..."), + contentBlockStop(0), + // Model calls recall again + toolUseBlockStart(1, "recall", "toolu_recall_2"), + inputJsonDelta(1, '{"id":"t:abc123"}'), + contentBlockStop(1), + messageDelta("tool_use"), + messageStop(), + ]; + + const output1 = processAll(cont1, cont1Events); + const parsed1 = parseForwardedEvents(output1); + + expect(cont1.hasRecall()).toBe(true); + expect(cont1.clientBlockCount()).toBe(1); // only the text block + + // Text block at index 0 - 0 suppressed + 3 offset = 3 + const textStart = parsed1.find((e) => e.event === "content_block_start"); + expect(textStart!.data.index).toBe(3); + + // No message_start forwarded + expect(parsed1.filter((e) => e.event === "message_start")).toHaveLength(0); + + // Terminal events held back (recall detected) + expect(cont1.heldBackEvents()).toContain("message_delta"); + + // Second continuation: blockOffset = 3 (prev) + 1 (cont1 client blocks) + 1 (marker) = 5 + const cont2 = createRecallAwareAccumulator("recall", { + blockOffset: 5, + suppressMessageStart: true, + }); + const 
cont2Events = [ + messageStart(), // suppressed + textBlockStart(0), + textDelta(0, "The specific error was..."), + contentBlockStop(0), + messageDelta(), + messageStop(), + ]; + + const output2 = processAll(cont2, cont2Events); + const parsed2 = parseForwardedEvents(output2); + + expect(cont2.hasRecall()).toBe(false); + expect(cont2.clientBlockCount()).toBe(1); + + // Text block at 0 + 5 = 5 + const text2Start = parsed2.find((e) => e.event === "content_block_start"); + expect(text2Start!.data.index).toBe(5); + + // Terminal events forwarded (no recall) + expect(parsed2.filter((e) => e.event === "message_delta")).toHaveLength(1); + expect(parsed2.filter((e) => e.event === "message_stop")).toHaveLength(1); + }); +}); + describe("Case 2 integration — mixed tools end-to-end", () => { test("full flow: suppress → extract → store → inject on next request", () => { // --- Step 1: Stream with text + recall + Read --- diff --git a/packages/gateway/test/recall.test.ts b/packages/gateway/test/recall.test.ts index d715a7af..74df2c16 100644 --- a/packages/gateway/test/recall.test.ts +++ b/packages/gateway/test/recall.test.ts @@ -12,6 +12,7 @@ import { describe, test, expect } from "bun:test"; import { RECALL_GATEWAY_TOOL, RECALL_TOOL_NAME, + MAX_RECALL_DEPTH, findRecallToolUse, hasRecallToolUse, hasOtherToolUse, @@ -111,6 +112,13 @@ describe("RECALL_GATEWAY_TOOL", () => { }); }); +describe("MAX_RECALL_DEPTH", () => { + test("is a positive integer safety-net cap", () => { + expect(MAX_RECALL_DEPTH).toBeGreaterThan(0); + expect(Number.isInteger(MAX_RECALL_DEPTH)).toBe(true); + }); +}); + // --------------------------------------------------------------------------- // Detection helpers // --------------------------------------------------------------------------- @@ -349,10 +357,10 @@ describe("buildRecallFollowUp", () => { "## Recall Results\n* config is in /root", ); - // Tools list should NOT include recall — prevents model from calling - // it again in the follow-up (which is 
piped without recall interception). - expect(followUp.tools).toHaveLength(1); - expect(followUp.tools[0].name).toBe("Read"); + // Tools list keeps recall — the continuation is recall-aware and + // can handle further recall calls (multi-turn recall). + expect(followUp.tools).toHaveLength(2); + expect(followUp.tools.map((t) => t.name).sort()).toEqual(["Read", "recall"]); }); test("preserves other request properties", () => {