From f55ca945bcd53f6f1c791a9a6aec2d41e846f5cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=91=E5=B8=83=E6=9E=97?= <11641432+heiheiyouyou@user.noreply.gitee.com> Date: Wed, 6 May 2026 19:27:21 +0800 Subject: [PATCH] fix:topk input search --- .../adapters/openclaw/tools.ts | 51 ++++++++++++++++--- .../core/pipeline/memory-core.ts | 24 ++++++++- .../core/retrieval/injector.ts | 2 +- .../core/retrieval/retrieve.ts | 16 +++--- .../unit/adapters/openclaw-bridge.test.ts | 41 +++++++++++++++ 5 files changed, 120 insertions(+), 14 deletions(-) diff --git a/apps/memos-local-plugin/adapters/openclaw/tools.ts b/apps/memos-local-plugin/adapters/openclaw/tools.ts index b098904d1..2effab7dc 100644 --- a/apps/memos-local-plugin/adapters/openclaw/tools.ts +++ b/apps/memos-local-plugin/adapters/openclaw/tools.ts @@ -42,7 +42,28 @@ const DEFAULT_BODY_CAP = 1200; const MemorySearchParams = Type.Object({ query: Type.String({ minLength: 1, description: "Free-text query (2–5 key words)." }), - maxResults: Type.Optional(Type.Integer({ minimum: 1, maximum: 50, default: 10 })), + maxResults: Type.Optional(Type.Integer({ minimum: 1, maximum: 50 })), + tier1topK: Type.Optional( + Type.Integer({ + minimum: 0, + maximum: 100, + description: "Override Skill (Tier 1) topK for this search only.", + }), + ), + tier2topK: Type.Optional( + Type.Integer({ + minimum: 0, + maximum: 100, + description: "Override trace/episode (Tier 2) topK for this search only.", + }), + ), + tier3topK: Type.Optional( + Type.Integer({ + minimum: 0, + maximum: 100, + description: "Override world-model (Tier 3) topK for this search only.", + }), + ), sessionScope: Type.Optional( Type.Boolean({ default: false, @@ -147,16 +168,15 @@ export function registerOpenClawTools(api: OpenClawPluginApi, opts: ToolsOptions const started = Date.now(); const core = await resolveCore(opts); const sessionId = params.sessionScope ? sessionFromCtx(ctx) : undefined; + const maxResults = params.maxResults !== undefined + ? Math.min(params.maxResults, 50) + : undefined; const result = await core.searchMemory({ agent: opts.agent, namespace: namespaceFromCtx(ctx), sessionId: sessionId as never, query: params.query, - topK: { - tier1: Math.min(params.maxResults ?? 10, 50), - tier2: Math.min(params.maxResults ?? 10, 50), - tier3: Math.min(params.maxResults ?? 10, 50), - }, + topK: topKParams(params, maxResults), }); return { hits: result.hits.map((h) => ({ @@ -404,6 +424,25 @@ export function registerOpenClawTools(api: OpenClawPluginApi, opts: ToolsOptions ); } +function topKParams( + params: MemorySearchParamsT, + maxResults: number | undefined, +): { tier1?: number; tier2?: number; tier3?: number } | undefined { + if ( + params.tier1topK === undefined && + params.tier2topK === undefined && + params.tier3topK === undefined && + maxResults === undefined + ) { + return undefined; + } + return { + tier1: params.tier1topK ?? maxResults, + tier2: params.tier2topK ?? maxResults, + tier3: params.tier3topK ?? maxResults, + }; +} + /** Exposed for tests + documentation. */ export const TOOL_SCHEMAS = { memory_search: MemorySearchParams, diff --git a/apps/memos-local-plugin/core/pipeline/memory-core.ts b/apps/memos-local-plugin/core/pipeline/memory-core.ts index 5add1a331..60d37b9a7 100644 --- a/apps/memos-local-plugin/core/pipeline/memory-core.ts +++ b/apps/memos-local-plugin/core/pipeline/memory-core.ts @@ -93,6 +93,7 @@ import { ownerFromNamespace, isVisibleTo, } from "../runtime/namespace.js"; +import type { RetrievalConfig } from "../retrieval/types.js"; // ─── Public bootstrap helpers ─────────────────────────────────────────────── @@ -1739,8 +1740,10 @@ export function createMemoryCore( ensureLive(); const ns = query.namespace ?? activeNamespace; activeNamespace = ns; + const baseDeps = handle.retrievalDeps(); const deps = { - ...handle.retrievalDeps(), + ...baseDeps, + config: applyTopKOverride(baseDeps.config, query.topK), namespace: ns, repos: wrapRetrievalRepos(handle.repos, ns), }; @@ -3735,6 +3738,25 @@ export function inferTier( return 2; } +function applyTopKOverride( + config: RetrievalConfig, + topK: RetrievalQueryDTO["topK"] | undefined, +): RetrievalConfig { + if (!topK) return config; + return { + ...config, + tier1TopK: clampTopK(topK.tier1, config.tier1TopK), + tier2TopK: clampTopK(topK.tier2, config.tier2TopK), + tier3TopK: clampTopK(topK.tier3, config.tier3TopK), + }; +} + +function clampTopK(value: number | undefined, fallback: number): number { + if (value === undefined) return fallback; + if (!Number.isFinite(value)) return fallback; + return Math.min(Math.max(0, Math.trunc(value)), 100); +} + function eventTime(evt: unknown): number { const at = (evt as { at?: unknown } | null)?.at; return typeof at === "number" && Number.isFinite(at) ? at : Date.now(); diff --git a/apps/memos-local-plugin/core/retrieval/injector.ts b/apps/memos-local-plugin/core/retrieval/injector.ts index dfbc41f4e..a3e5bc048 100644 --- a/apps/memos-local-plugin/core/retrieval/injector.ts +++ b/apps/memos-local-plugin/core/retrieval/injector.ts @@ -417,7 +417,7 @@ const HEADER_BY_REASON: Record = { }; const FOOTER_LINES_COMMON: readonly string[] = [ - "- `memory_search(query, maxResults?)` — re-query with a shorter / rephrased string", + "- `memory_search(query, maxResults?, tier1topK?, tier2topK?, tier3topK?)` — re-query with a shorter / rephrased string", "- `memory_get(id, kind?)` — fetch a full trace / policy / world-model body by refId", "- `memory_timeline(episodeId, limit?)` — expand an episode into its step-by-step traces", ]; diff --git a/apps/memos-local-plugin/core/retrieval/retrieve.ts b/apps/memos-local-plugin/core/retrieval/retrieve.ts index 141e5ec71..07e795c43 100644 --- a/apps/memos-local-plugin/core/retrieval/retrieve.ts +++ b/apps/memos-local-plugin/core/retrieval/retrieve.ts @@ -224,9 +224,13 @@ async function runAll( const noUsableChannel = !queryVec && !haveKeywordChannel; // Kick off the tiers in parallel — each resolves to its own list. + const wantTier1 = plan.wantTier1 && deps.config.tier1TopK > 0; + const wantTier2 = plan.wantTier2 && deps.config.tier2TopK > 0; + const wantTier3 = plan.wantTier3 && deps.config.tier3TopK > 0; + const tier1Start = Date.now(); const tier1Promise: Promise = - plan.wantTier1 && !noUsableChannel + wantTier1 && !noUsableChannel ? runTier1( { repos: deps.repos, config: deps.config }, { @@ -241,7 +245,7 @@ async function runAll( const tier2Start = Date.now(); const tier2Promise: Promise<{ traces: TraceCandidate[]; episodes: EpisodeCandidate[] }> = - plan.wantTier2 && !noUsableChannel + wantTier2 && !noUsableChannel ? runTier2( { repos: deps.repos, config: deps.config, now: deps.now }, { @@ -257,7 +261,7 @@ async function runAll( const tier3Start = Date.now(); const tier3Promise: Promise = - plan.wantTier3 && !noUsableChannel + wantTier3 && !noUsableChannel ? runTier3( { repos: deps.repos, config: deps.config }, { @@ -274,9 +278,9 @@ async function runAll( tier3Promise, ]); - const tier1LatencyMs = plan.wantTier1 ? Date.now() - tier1Start : 0; - const tier2LatencyMs = plan.wantTier2 ? Date.now() - tier2Start : 0; - const tier3LatencyMs = plan.wantTier3 ? Date.now() - tier3Start : 0; + const tier1LatencyMs = wantTier1 ? Date.now() - tier1Start : 0; + const tier2LatencyMs = wantTier2 ? Date.now() - tier2Start : 0; + const tier3LatencyMs = wantTier3 ? Date.now() - tier3Start : 0; const fuseStart = Date.now(); const rawCandidateCount = diff --git a/apps/memos-local-plugin/tests/unit/adapters/openclaw-bridge.test.ts b/apps/memos-local-plugin/tests/unit/adapters/openclaw-bridge.test.ts index 3ec49c761..71258633c 100644 --- a/apps/memos-local-plugin/tests/unit/adapters/openclaw-bridge.test.ts +++ b/apps/memos-local-plugin/tests/unit/adapters/openclaw-bridge.test.ts @@ -1366,6 +1366,47 @@ describe("registerOpenClawTools", () => { expect(res.totalMs).toBeGreaterThanOrEqual(0); }); + it("memory_search maps per-tier topK params and keeps maxResults fallback", async () => { + const searchMemory = vi.fn(async () => ({ + hits: [], + injectedContext: "", + tierLatencyMs: { tier1: 0, tier2: 0, tier3: 0 }, + })); + const mc = { searchMemory } as unknown as MemoryCore; + + const { api, tools } = collectTools(); + registerOpenClawTools(api, { + agent: "openclaw", + core: mc, + log: silentLogger(), + }); + const search = tools.find((t) => t.descriptor.name === "memory_search")!; + + await search.descriptor.execute("toolCall_1", { + query: "anything", + maxResults: 7, + tier1topK: 2, + tier3topK: 0, + }); + expect(searchMemory).toHaveBeenLastCalledWith( + expect.objectContaining({ + query: "anything", + topK: { tier1: 2, tier2: 7, tier3: 0 }, + }), + ); + + await search.descriptor.execute("toolCall_2", { + query: "fallback", + maxResults: 4, + }); + expect(searchMemory).toHaveBeenLastCalledWith( + expect.objectContaining({ + query: "fallback", + topK: { tier1: 4, tier2: 4, tier3: 4 }, + }), + ); + }); + it("registers tool shells before the async core is resolved", async () => { const mc = buildCore(); await mc.init();