From 36fd1d58c79906a2f0b28021d149a54427cc1f60 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sun, 10 May 2026 10:36:56 +0800 Subject: [PATCH 1/5] =?UTF-8?q?fix(memo):memo=E5=88=9D=E6=AD=A5=E4=BF=AE?= =?UTF-8?q?=E7=90=86,run=E6=88=AA=E5=8F=96=EF=BC=8Cllm=E6=8F=90=E5=8F=96?= =?UTF-8?q?=E6=97=B6=E5=8E=BB=E9=87=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/bootstrap_test.go | 42 +++++++ internal/context/projection.go | 36 ++++++ internal/context/projection_test.go | 44 ++++++++ internal/memo/auto_extractor.go | 64 +++++++++++ internal/memo/auto_extractor_test.go | 144 ++++++++++++++++++++++++ internal/memo/llm_extractor.go | 123 ++++++++++++++++++--- internal/memo/llm_extractor_test.go | 61 +++++++++-- internal/memo/service.go | 158 +++++++++++++++++++++++++++ internal/memo/types.go | 38 +++++++ internal/runtime/memo.go | 24 ++++ internal/runtime/run.go | 2 +- internal/runtime/runtime.go | 2 +- internal/runtime/runtime_test.go | 58 ++++++++++ internal/runtime/session_mutation.go | 23 +++- internal/runtime/state.go | 2 + 15 files changed, 793 insertions(+), 28 deletions(-) diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index 3760dded..917823cb 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -1675,6 +1675,48 @@ func TestNewMemoExtractorAdapterBuildsProviderSafeMemoWindow(t *testing.T) { } } +func TestNewMemoExtractorAdapterUsesFullRunMemoWindow(t *testing.T) { + t.Setenv(config.OpenAIDefaultAPIKeyEnv, "token") + cfg := config.StaticDefaults().Clone() + cfg.SelectedProvider = config.OpenAIName + cfg.Memo.ExtractRecentMessages = 3 + manager := config.NewManager(config.NewLoader("", &cfg)) + + providerStub := &stubMemoProvider{ + generate: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + if len(req.Messages) != 12 { + t.Fatalf("unexpected memo window length %d, want full run: %+v", len(req.Messages), req.Messages) + } + events <- providertypes.NewTextDeltaStreamEvent(`[]`) + events <- providertypes.NewMessageDoneStreamEvent("stop", nil) + return nil + }, + } + factory := &stubMemoProviderFactory{provider: providerStub} + scheduler := &stubMemoExtractorScheduler{} + extractor := newMemoExtractorAdapter(factory, manager, scheduler) + + inputMessages := make([]providertypes.Message, 0, 12) + for index := 0; index < 12; index++ { + inputMessages = append(inputMessages, providertypes.Message{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart(fmt.Sprintf("message-%02d", index))}, + }) + } + extractor.Schedule("session-1", inputMessages) + if !scheduler.called || scheduler.extractor == nil { + t.Fatalf("expected scheduler to receive extractor") + } + + _, err := scheduler.extractor.Extract(context.Background(), inputMessages) + if err != nil { + t.Fatalf("extractor.Extract() error = %v", err) + } + if !factory.called { + t.Fatalf("expected provider factory Build to be called") + } +} + func TestNewMemoExtractorAdapterKeepsScheduledConfigSnapshot(t *testing.T) { t.Setenv(config.OpenAIDefaultAPIKeyEnv, "openai-token") t.Setenv(config.QiniuDefaultAPIKeyEnv, "qiniu-token") diff --git a/internal/context/projection.go b/internal/context/projection.go index ac7da4b3..298f7e8c 100644 --- a/internal/context/projection.go +++ b/internal/context/projection.go @@ -79,6 +79,42 @@ func BuildRecentMessagesForModel(messages []providertypes.Message, limit int) [] return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected))) } +// BuildMemoExtractionMessagesForModel 构造完整 run 的 provider-safe 记忆提取上下文。 +func BuildMemoExtractionMessagesForModel(messages []providertypes.Message) []providertypes.Message { + if len(messages) == 0 { + return nil + } + + keep := make([]bool, len(messages)) + for index := 0; index < len(messages); index++ { + message := messages[index] + if message.Role == providertypes.RoleTool { + continue + } + + if message.Role == providertypes.RoleAssistant && len(message.ToolCalls) > 0 { + for _, spanIndex := range matchedToolCallSpan(messages, index) { + keep[spanIndex] = true + } + continue + } + + keep[index] = true + } + + selected := make([]providertypes.Message, 0, len(messages)) + for index, message := range messages { + if keep[index] { + selected = append(selected, message) + } + } + if len(selected) == 0 { + return nil + } + + return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected))) +} + // matchedToolCallSpan 返回 assistant tool call 与其完整 tool 响应组成的合法窗口下标集合。 func matchedToolCallSpan(messages []providertypes.Message, assistantIndex int) []int { if assistantIndex < 0 || assistantIndex >= len(messages) { diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go index 372e3c3f..3b0c932a 100644 --- a/internal/context/projection_test.go +++ b/internal/context/projection_test.go @@ -140,6 +140,50 @@ func TestBuildRecentMessagesForModelKeepsOnlyRecentValidAnchors(t *testing.T) { } } +func TestBuildMemoExtractionMessagesForModelKeepsFullRunSafeSpans(t *testing.T) { + t.Parallel() + + messages := []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}}, + {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan")}}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("README body")}, + ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, + }, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-missing", Name: "bash", Arguments: `{}`}, + }, + }, + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("last")}}, + } + + projected := BuildMemoExtractionMessagesForModel(messages) + if len(projected) != 4 { + t.Fatalf("len(projected) = %d, want 4: %+v", len(projected), projected) + } + if renderDisplayParts(projected[0].Parts) != "first" || renderDisplayParts(projected[3].Parts) != "last" { + t.Fatalf("expected full run user messages to remain, got %+v", projected) + } + if projected[1].Role != providertypes.RoleAssistant || len(projected[1].ToolCalls) != 1 { + t.Fatalf("expected complete assistant tool span, got %+v", projected[1]) + } + if projected[2].Role != providertypes.RoleTool || + !strings.Contains(renderDisplayParts(projected[2].Parts), "tool result") || + projected[2].ToolMetadata != nil { + t.Fatalf("expected projected tool result, got %+v", projected[2]) + } +} + func TestBuildRecentMessagesForModelRespectsAbsoluteMessageBudget(t *testing.T) { t.Parallel() diff --git a/internal/memo/auto_extractor.go b/internal/memo/auto_extractor.go index aaadfec9..a1ccb43b 100644 --- a/internal/memo/auto_extractor.go +++ b/internal/memo/auto_extractor.go @@ -225,6 +225,24 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider ctx, cancel := context.WithTimeout(context.Background(), a.extractTimeout) defer cancel() + if decisionExtractor, ok := extractor.(DecisionExtractor); ok { + existing, err := a.svc.autoExtractionCandidates(ctx) + if err != nil { + a.logError("memo: auto extract load candidates failed: %v", err) + return false + } + decisions, err := decisionExtractor.ExtractDecisions(ctx, messages, existing) + if err != nil { + if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) { + a.logError("memo: auto extract skipped (protocol_mismatch): %v", err) + return true + } + a.logError("memo: auto extract failed: %v", err) + return false + } + return a.applyExtractionDecisions(ctx, decisions) + } + entries, err := extractor.Extract(ctx, messages) if err != nil { if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) { @@ -265,6 +283,52 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider return succeeded } +// applyExtractionDecisions 应用模型返回的 create/update/skip 决策,并保留本地精确去重兜底。 +func (a *AutoExtractor) applyExtractionDecisions(ctx context.Context, decisions []ExtractionDecision) bool { + if len(decisions) == 0 { + return true + } + + seenCreates := make(map[string]struct{}, len(decisions)) + succeeded := true + for _, decision := range decisions { + switch decision.Action { + case ExtractionActionCreate: + entry := decision.Entry + entry.Source = SourceAutoExtract + key := autoExtractDedupKey(entry) + if key == "" { + continue + } + if _, exists := seenCreates[key]; exists { + continue + } + added, err := a.svc.addAutoExtractIfAbsent(ctx, entry) + if err != nil { + a.logError("memo: auto extract add failed: %v", err) + succeeded = false + continue + } + seenCreates[key] = struct{}{} + if !added { + continue + } + case ExtractionActionUpdate: + entry := decision.Entry + entry.Source = SourceAutoExtract + _, err := a.svc.updateAutoExtractIfAllowed(ctx, decision.Ref, entry) + if err != nil { + a.logError("memo: auto extract update failed: %v", err) + succeeded = false + continue + } + case ExtractionActionSkip: + continue + } + } + return succeeded +} + // autoExtractDedupKey 生成自动提取条目的精确去重键。 func autoExtractDedupKey(entry Entry) string { title := NormalizeTitle(entry.Title) diff --git a/internal/memo/auto_extractor_test.go b/internal/memo/auto_extractor_test.go index 2f188d07..6c4e5835 100644 --- a/internal/memo/auto_extractor_test.go +++ b/internal/memo/auto_extractor_test.go @@ -38,6 +38,33 @@ func (s *stubMemoExtractor) Calls() int { return s.callCount } +type stubDecisionMemoExtractor struct { + mu sync.Mutex + callCount int + candidates []ExtractionCandidate + decisions []ExtractionDecision + err error +} + +func (s *stubDecisionMemoExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { + return nil, nil +} + +func (s *stubDecisionMemoExtractor) ExtractDecisions( + ctx context.Context, + messages []providertypes.Message, + existing []ExtractionCandidate, +) ([]ExtractionDecision, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.callCount++ + s.candidates = append([]ExtractionCandidate(nil), existing...) + if s.err != nil { + return nil, s.err + } + return append([]ExtractionDecision(nil), s.decisions...), nil +} + func newAutoExtractorTestService(t *testing.T) *Service { t.Helper() store := NewFileStore(t.TempDir(), t.TempDir()) @@ -214,6 +241,123 @@ func TestAutoExtractorSuppressesExactDuplicates(t *testing.T) { }) } +func TestAutoExtractorAppliesSemanticUpdateOnlyForAutoExtractedMemory(t *testing.T) { + svc := newAutoExtractorTestService(t) + if err := svc.Add(context.Background(), Entry{ + Type: TypeFeedback, + Title: "测试策略", + Content: "用户要求修改后先跑测试。", + Source: SourceAutoExtract, + }); err != nil { + t.Fatalf("seed auto Add() error = %v", err) + } + if err := svc.Add(context.Background(), Entry{ + Type: TypeUser, + Title: "语言偏好", + Content: "用户手动保存偏好中文回复。", + Source: SourceUserManual, + }); err != nil { + t.Fatalf("seed manual Add() error = %v", err) + } + + candidates, err := svc.autoExtractionCandidates(context.Background()) + if err != nil { + t.Fatalf("autoExtractionCandidates() error = %v", err) + } + var autoRef, manualRef string + for _, candidate := range candidates { + switch candidate.Source { + case SourceAutoExtract: + autoRef = candidate.Ref + case SourceUserManual: + manualRef = candidate.Ref + } + } + if autoRef == "" || manualRef == "" { + t.Fatalf("expected refs, auto=%q manual=%q candidates=%+v", autoRef, manualRef, candidates) + } + + extractor := &stubDecisionMemoExtractor{ + decisions: []ExtractionDecision{ + { + Action: ExtractionActionUpdate, + Ref: autoRef, + Entry: Entry{ + Title: "测试策略", + Content: "用户要求修改后先跑相关测试。", + Keywords: []string{"test"}, + }, + }, + { + Action: ExtractionActionUpdate, + Ref: manualRef, + Entry: Entry{ + Title: "语言偏好", + Content: "不应覆盖手动记忆。", + }, + }, + }, + } + auto := NewAutoExtractor(extractor, svc, time.Second) + auto.debounce = 5 * time.Millisecond + auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) + + auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("更新测试策略")}}}) + waitFor(t, time.Second, func() bool { + extractor.mu.Lock() + defer extractor.mu.Unlock() + return extractor.callCount == 1 + }) + + recall, err := svc.Recall(context.Background(), "测试策略", ScopeProject) + if err != nil { + t.Fatalf("Recall(auto) error = %v", err) + } + if len(recall) != 1 || !strings.Contains(recall[0].Content, "相关测试") { + t.Fatalf("expected auto memory to be updated, got %+v", recall) + } + manualRecall, err := svc.Recall(context.Background(), "语言偏好", ScopeUser) + if err != nil { + t.Fatalf("Recall(manual) error = %v", err) + } + if len(manualRecall) != 1 || strings.Contains(manualRecall[0].Content, "不应覆盖") { + t.Fatalf("manual memory should not be overwritten, got %+v", manualRecall) + } + if len(extractor.candidates) != 2 { + t.Fatalf("expected existing candidates to be provided, got %+v", extractor.candidates) + } +} + +func TestAutoExtractorSemanticCreateStillUsesExactDedup(t *testing.T) { + svc := newAutoExtractorTestService(t) + if err := svc.Add(context.Background(), Entry{ + Type: TypeUser, + Title: "中文回复", + Content: "用户偏好中文回复。", + Source: SourceAutoExtract, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + extractor := &stubDecisionMemoExtractor{ + decisions: []ExtractionDecision{ + {Action: ExtractionActionCreate, Entry: Entry{Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"}}, + {Action: ExtractionActionCreate, Entry: Entry{Type: TypeProject, Title: "新事实", Content: "项目需要语义去重。"}}, + }, + } + auto := NewAutoExtractor(extractor, svc, time.Second) + auto.debounce = 5 * time.Millisecond + auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) + + auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("dedupe")}}}) + waitFor(t, time.Second, func() bool { + entries, err := svc.List(context.Background(), ScopeAll) + return err == nil && len(entries) == 2 + }) +} + func TestAutoExtractorUsesTimeoutContext(t *testing.T) { svc := newAutoExtractorTestService(t) extractor := &stubMemoExtractor{ diff --git a/internal/memo/llm_extractor.go b/internal/memo/llm_extractor.go index 9d039f44..621446de 100644 --- a/internal/memo/llm_extractor.go +++ b/internal/memo/llm_extractor.go @@ -27,6 +27,8 @@ type LLMExtractor struct { } type extractedEntry struct { + Action string `json:"action"` + Ref string `json:"ref"` Type string `json:"type"` Title string `json:"title"` Content string `json:"content"` @@ -45,8 +47,27 @@ func NewLLMExtractor(generator TextGenerator, recentMessageLimit int) *LLMExtrac } } -// Extract 从最近对话中提取可跨会话持久化的记忆条目。 +// Extract 从当前 run 对话中提取可跨会话持久化的新增记忆条目。 func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { + decisions, err := e.ExtractDecisions(ctx, messages, nil) + if err != nil { + return nil, err + } + entries := make([]Entry, 0, len(decisions)) + for _, decision := range decisions { + if decision.Action == ExtractionActionCreate { + entries = append(entries, decision.Entry) + } + } + return entries, nil +} + +// ExtractDecisions 从当前 run 对话中提取记忆,并结合既有记忆输出新增、合并或跳过决策。 +func (e *LLMExtractor) ExtractDecisions( + ctx context.Context, + messages []providertypes.Message, + existing []ExtractionCandidate, +) ([]ExtractionDecision, error) { if err := ctx.Err(); err != nil { return nil, err } @@ -54,12 +75,12 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes return nil, errors.New("memo: text generator is nil") } - recent := agentcontext.BuildRecentMessagesForModel(messages, e.recentMessageLimit) - if len(recent) == 0 || !containsUserMessage(recent) { + runMessages := agentcontext.BuildMemoExtractionMessagesForModel(messages) + if len(runMessages) == 0 || !containsUserMessage(runMessages) { return nil, nil } - response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now()), recent) + response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now(), existing), runMessages) if err != nil { return nil, err } @@ -77,23 +98,29 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes return nil, fmt.Errorf("memo: parse extraction response: %w", err) } - entries := make([]Entry, 0, len(extracted)) + decisions := make([]ExtractionDecision, 0, len(extracted)) for _, item := range extracted { - entry, ok := toMemoEntry(item) + decision, ok := toExtractionDecision(item) if !ok { continue } - entries = append(entries, entry) + decisions = append(decisions, decision) } - return entries, nil + return decisions, nil } // buildExtractionPrompt 构造记忆提取专用的 system prompt。 -func buildExtractionPrompt(now time.Time) string { +func buildExtractionPrompt(now time.Time, existing []ExtractionCandidate) string { currentDate := now.In(time.Local).Format("2006-01-02") + existingJSON := "[]" + if len(existing) > 0 { + if data, err := json.Marshal(existing); err == nil { + existingJSON = string(data) + } + } return strings.TrimSpace(fmt.Sprintf(` 你是一个记忆提取助手(memory extraction assistant)。 -分析最近对话中值得跨会话持久记住的信息,并返回严格 JSON 数组。 +分析当前 run 对话中值得跨会话持久记住的信息,并结合既有记忆完成语义去重,返回严格 JSON 数组。 当前本地日期:%s 如果对话中出现“今天、明天、下周二”等相对日期,必须先转换为绝对日期再写入 content。 @@ -109,11 +136,18 @@ func buildExtractionPrompt(now time.Time) string { 2. 不要提取通用编程知识、代码结构、文件路径、Git 历史。 3. 每条记忆必须具体、可操作。 4. 没有值得记住的信息时,返回 []。 -5. 输出必须是 JSON 数组,不要输出任何额外解释。 +5. 如果新信息与既有记忆语义相同或只是轻微改写,输出 action="skip"。 +6. 如果新信息能补充或修正既有 source="extractor_auto" 的记忆,输出 action="update" 并填写目标 ref。 +7. 不允许 update source 不是 "extractor_auto" 的既有记忆;这类相近内容只能 skip。 +8. 如果是全新的可持久化信息,输出 action="create"。 +9. 输出必须是 JSON 数组,不要输出任何额外解释。 + +既有记忆候选(JSON): +%s 输出格式: -[{"type":"user","title":"...","content":"...","keywords":["..."]}] -`, currentDate)) +[{"action":"create","type":"user","title":"...","content":"...","keywords":["..."]},{"action":"update","ref":"project:p.md","title":"...","content":"...","keywords":["..."]},{"action":"skip","ref":"user:u.md"}] +`, currentDate, existingJSON)) } // containsUserMessage 检查待提取消息中是否包含用户输入。 @@ -148,6 +182,69 @@ func toMemoEntry(item extractedEntry) (Entry, bool) { }, true } +// toExtractionDecision 将 LLM 输出收敛为自动提取持久化决策。 +func toExtractionDecision(item extractedEntry) (ExtractionDecision, bool) { + action := parseExtractionAction(item.Action) + if action == "" { + return ExtractionDecision{}, false + } + if action == ExtractionActionSkip { + return ExtractionDecision{ + Action: action, + Ref: strings.TrimSpace(item.Ref), + }, true + } + + if action == ExtractionActionUpdate { + ref := strings.TrimSpace(item.Ref) + if ref == "" { + return ExtractionDecision{}, false + } + entry, ok := toMemoUpdateEntry(item) + if !ok { + return ExtractionDecision{}, false + } + return ExtractionDecision{Action: action, Ref: ref, Entry: entry}, true + } + + entry, ok := toMemoEntry(item) + if !ok { + return ExtractionDecision{}, false + } + return ExtractionDecision{Action: action, Entry: entry}, true +} + +// toMemoUpdateEntry 将 update 决策中的可变字段收敛为 Entry 片段。 +func toMemoUpdateEntry(item extractedEntry) (Entry, bool) { + title := NormalizeTitle(item.Title) + content := strings.TrimSpace(item.Content) + if title == "" || content == "" { + return Entry{}, false + } + return Entry{ + Title: title, + Content: content, + Keywords: normalizeKeywords(item.Keywords), + Source: SourceAutoExtract, + }, true +} + +// parseExtractionAction 解析模型决策动作,并兼容旧格式中缺省 action 的 create 输出。 +func parseExtractionAction(action string) ExtractionAction { + switch ExtractionAction(strings.ToLower(strings.TrimSpace(action))) { + case "": + return ExtractionActionCreate + case ExtractionActionCreate: + return ExtractionActionCreate + case ExtractionActionUpdate: + return ExtractionActionUpdate + case ExtractionActionSkip: + return ExtractionActionSkip + default: + return "" + } +} + // normalizeKeywords 规范化关键词列表,移除空值和重复值。 func normalizeKeywords(keywords []string) []string { if len(keywords) == 0 { diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go index 3ccf9ee9..06813229 100644 --- a/internal/memo/llm_extractor_test.go +++ b/internal/memo/llm_extractor_test.go @@ -196,8 +196,8 @@ func TestLLMExtractorExtractCancelledContext(t *testing.T) { } } -// TestLLMExtractorExtractUsesRecentNonToolMessages 验证只取最近 10 条非 tool 消息。 -func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) { +// TestLLMExtractorExtractUsesFullRunMessages 验证提取器使用完整 run 消息而非固定 recent window。 +func TestLLMExtractorExtractUsesFullRunMessages(t *testing.T) { generator := &stubTextGenerator{response: `[]`} extractor := NewLLMExtractor(generator, 10) @@ -219,19 +219,64 @@ func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) { if err != nil { t.Fatalf("Extract() error = %v", err) } - if len(generator.messages) != 10 { - t.Fatalf("len(generator.messages) = %d, want 10", len(generator.messages)) + if len(generator.messages) != 12 { + t.Fatalf("len(generator.messages) = %d, want 12", len(generator.messages)) } for _, message := range generator.messages { if message.Role == providertypes.RoleTool { t.Fatalf("unexpected tool message in extraction context: %#v", message) } } - if renderMemoParts(generator.messages[0].Parts) != "user-c" || - renderMemoParts(generator.messages[9].Parts) != "user-l" { - t.Fatalf("unexpected recent window: first=%q last=%q", + if renderMemoParts(generator.messages[0].Parts) != "user-a" || + renderMemoParts(generator.messages[11].Parts) != "user-l" { + t.Fatalf("unexpected run window: first=%q last=%q", renderMemoParts(generator.messages[0].Parts), - renderMemoParts(generator.messages[9].Parts)) + renderMemoParts(generator.messages[11].Parts)) + } +} + +func TestLLMExtractorExtractDecisionsIncludesExistingCandidates(t *testing.T) { + generator := &stubTextGenerator{ + response: `[{"action":"skip","ref":"user:u.md"},{"action":"update","ref":"project:p.md","title":"测试策略","content":"用户要求修改后先跑相关测试。","keywords":["test"]}]`, + } + extractor := NewLLMExtractor(generator, 10) + + decisions, err := extractor.ExtractDecisions( + context.Background(), + []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑相关测试。")}}, + }, + []ExtractionCandidate{ + { + Ref: "project:p.md", + Scope: ScopeProject, + Type: TypeFeedback, + Source: SourceAutoExtract, + Title: "测试策略", + Content: "用户要求修改后先跑测试。", + }, + }, + ) + if err != nil { + t.Fatalf("ExtractDecisions() error = %v", err) + } + if len(decisions) != 2 { + t.Fatalf("len(decisions) = %d, want 2", len(decisions)) + } + if decisions[0].Action != ExtractionActionSkip || decisions[0].Ref != "user:u.md" { + t.Fatalf("unexpected skip decision: %+v", decisions[0]) + } + if decisions[1].Action != ExtractionActionUpdate || decisions[1].Ref != "project:p.md" { + t.Fatalf("unexpected update decision: %+v", decisions[1]) + } + if decisions[1].Entry.Type != "" { + t.Fatalf("update decision should not require type, got %+v", decisions[1].Entry) + } + if !strings.Contains(generator.prompt, `"ref":"project:p.md"`) { + t.Fatalf("prompt should include existing memory candidates, got %q", generator.prompt) + } + if !strings.Contains(generator.prompt, `action="update"`) || !strings.Contains(generator.prompt, `source="extractor_auto"`) { + t.Fatalf("prompt should describe semantic dedupe protocol, got %q", generator.prompt) } } diff --git a/internal/memo/service.go b/internal/memo/service.go index c7415c50..42551f26 100644 --- a/internal/memo/service.go +++ b/internal/memo/service.go @@ -68,6 +68,130 @@ func (s *Service) addAutoExtractIfAbsent(ctx context.Context, entry Entry) (bool return true, nil } +// autoExtractionCandidates 加载既有记忆快照,供模型在提取时做语义去重判断。 +func (s *Service) autoExtractionCandidates(ctx context.Context) ([]ExtractionCandidate, error) { + if s == nil { + return nil, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + candidates := make([]ExtractionCandidate, 0) + for _, scope := range supportedStorageScopes() { + index, err := s.loadIndexLocked(ctx, scope) + if err != nil { + return nil, err + } + for _, entry := range index.Entries { + topicFile := strings.TrimSpace(entry.TopicFile) + if topicFile == "" { + continue + } + topicContent, err := s.store.LoadTopic(ctx, scope, topicFile) + if err != nil { + continue + } + source, content := parseTopicSourceAndContent(topicContent) + candidates = append(candidates, ExtractionCandidate{ + Ref: scopedTopicKey(scope, topicFile), + Scope: scope, + Type: entry.Type, + Source: source, + Title: entry.Title, + Content: content, + }) + } + } + return candidates, nil +} + +// updateAutoExtractIfAllowed 按 ref 更新既有自动提取记忆,显式记忆不会被后台流程覆盖。 +func (s *Service) updateAutoExtractIfAllowed(ctx context.Context, ref string, next Entry) (bool, error) { + next.Title = NormalizeTitle(next.Title) + next.Content = strings.TrimSpace(next.Content) + next.Keywords = normalizeKeywords(next.Keywords) + if next.Title == "" { + return false, fmt.Errorf("memo: title is empty") + } + if next.Content == "" { + return false, fmt.Errorf("memo: content is empty") + } + if err := s.ensureAutoExtractIndex(ctx); err != nil { + return false, err + } + + scope, topicFile, ok := parseScopedTopicKey(ref) + if !ok { + return false, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + index, err := s.loadIndexLocked(ctx, scope) + if err != nil { + return false, err + } + working := cloneIndex(index) + + targetIndex := -1 + var current Entry + for idx, existing := range working.Entries { + if strings.TrimSpace(existing.TopicFile) == topicFile { + targetIndex = idx + current = existing + break + } + } + if targetIndex < 0 { + return false, nil + } + + topicContent, err := s.store.LoadTopic(ctx, scope, topicFile) + if err != nil { + return false, nil + } + source, _ := parseTopicSourceAndContent(topicContent) + if source != SourceAutoExtract { + return false, nil + } + + current.Title = next.Title + current.Content = next.Content + current.Keywords = next.Keywords + current.Source = SourceAutoExtract + current.TopicFile = topicFile + working.Entries[targetIndex] = current + working.UpdatedAt = time.Now() + + removedEntries := trimIndexEntries(working, s.config.MaxEntries, s.config.MaxIndexBytes) + if err := s.store.SaveTopic(ctx, scope, topicFile, RenderTopic(¤t)); err != nil { + return false, fmt.Errorf("memo: save topic: %w", err) + } + if err := s.store.SaveIndex(ctx, scope, working); err != nil { + return false, fmt.Errorf("memo: save index: %w", err) + } + for _, removed := range removedEntries { + if removedTopic := strings.TrimSpace(removed.TopicFile); removedTopic != "" { + _ = s.store.DeleteTopic(ctx, scope, removedTopic) + } + } + + if s.autoExtractIndexReady { + s.removeAutoExtractTopicLocked(scope, topicFile) + for _, removed := range removedEntries { + s.removeAutoExtractTopicLocked(scope, removed.TopicFile) + } + if indexContainsTopicFile(working, topicFile) { + s.trackAutoExtractEntryLocked(scope, current) + } + } + + s.invalidateCache() + return true, nil +} + // normalizeKeyword 统一关键词的空格与大小写处理。 func normalizeKeyword(keyword string) string { return strings.ToLower(strings.TrimSpace(keyword)) @@ -557,6 +681,23 @@ func indexContainsEntryID(index *Index, entryID string) bool { return false } +// indexContainsTopicFile 判断索引中是否仍保留指定 topic 文件。 +func indexContainsTopicFile(index *Index, topicFile string) bool { + if index == nil { + return false + } + topicFile = strings.TrimSpace(topicFile) + if topicFile == "" { + return false + } + for _, item := range index.Entries { + if strings.TrimSpace(item.TopicFile) == topicFile { + return true + } + } + return false +} + // scopesForQuery 将查询范围展开为实际存储分层列表。 func scopesForQuery(scope Scope) []Scope { switch NormalizeScope(scope) { @@ -584,3 +725,20 @@ func validateQueryScope(scope Scope) error { func scopedTopicKey(scope Scope, topicFile string) string { return string(scope) + ":" + strings.TrimSpace(topicFile) } + +// parseScopedTopicKey 解析自动提取语义去重协议中的 ref 字段。 +func parseScopedTopicKey(ref string) (Scope, string, bool) { + parts := strings.SplitN(strings.TrimSpace(ref), ":", 2) + if len(parts) != 2 { + return "", "", false + } + scope := Scope(strings.TrimSpace(parts[0])) + if err := validateStorageScope(scope); err != nil { + return "", "", false + } + topicFile := strings.TrimSpace(parts[1]) + if topicFile == "" { + return "", "", false + } + return scope, topicFile, true +} diff --git a/internal/memo/types.go b/internal/memo/types.go index 8572e0e4..0d1a62a0 100644 --- a/internal/memo/types.go +++ b/internal/memo/types.go @@ -89,6 +89,35 @@ type RecalledEntry struct { Content string } +// ExtractionAction 表示自动提取器对单条候选记忆的持久化决策。 +type ExtractionAction string + +const ( + // ExtractionActionCreate 表示新增一条记忆。 + ExtractionActionCreate ExtractionAction = "create" + // ExtractionActionUpdate 表示合并更新一条既有自动提取记忆。 + ExtractionActionUpdate ExtractionAction = "update" + // ExtractionActionSkip 表示跳过重复或不值得沉淀的内容。 + ExtractionActionSkip ExtractionAction = "skip" +) + +// ExtractionCandidate 表示提供给模型做语义去重的既有记忆快照。 +type ExtractionCandidate struct { + Ref string `json:"ref"` + Scope Scope `json:"scope"` + Type Type `json:"type"` + Source string `json:"source"` + Title string `json:"title"` + Content string `json:"content"` +} + +// ExtractionDecision 表示模型针对新旧记忆关系返回的结构化决策。 +type ExtractionDecision struct { + Action ExtractionAction + Ref string + Entry Entry +} + // Store 定义记忆持久化的最小抽象。 type Store interface { LoadIndex(ctx context.Context, scope Scope) (*Index, error) @@ -104,6 +133,15 @@ type Extractor interface { Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) } +// DecisionExtractor 定义带既有记忆快照的语义提取能力。 +type DecisionExtractor interface { + ExtractDecisions( + ctx context.Context, + messages []providertypes.Message, + existing []ExtractionCandidate, + ) ([]ExtractionDecision, error) +} + // TextGenerator 定义调用 LLM 生成文本的最小能力,用于记忆提取。 // 该接口隔离 provider 细节,避免 memo 包直接依赖 provider 基础设施。 type TextGenerator interface { diff --git a/internal/runtime/memo.go b/internal/runtime/memo.go index 477933d7..d2ce6202 100644 --- a/internal/runtime/memo.go +++ b/internal/runtime/memo.go @@ -19,6 +19,30 @@ func (s *Service) triggerMemoExtraction(sessionID string, messages []providertyp s.memoExtractor.Schedule(sessionID, cloneMessages(messages)) } +// runBoundaryMessagesForMemo 返回当前 run 边界内的消息切片,供自动记忆提取使用。 +func runBoundaryMessagesForMemo(state *runState) []providertypes.Message { + if state == nil { + return nil + } + state.mu.Lock() + defer state.mu.Unlock() + return cloneMessages(state.memoRunMessages) +} + +// appendMemoRunMessage 记录当前 run 内已成功写入 transcript 的消息,作为自动记忆提取边界。 +func appendMemoRunMessage(state *runState, message providertypes.Message) { + if state == nil || message.IsEmpty() { + return + } + cloned := cloneMessages([]providertypes.Message{message}) + if len(cloned) == 0 { + return + } + state.mu.Lock() + defer state.mu.Unlock() + state.memoRunMessages = append(state.memoRunMessages, cloned[0]) +} + // isSuccessfulRememberToolCall 判断工具调用是否成功完成显式记忆写入。 func isSuccessfulRememberToolCall(callName string, result tools.ToolResult, execErr error) bool { if execErr != nil || result.IsError { diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 6887ee55..4fb3d22d 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -472,7 +472,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { }) recordAcceptanceTerminal(&state, acceptanceDecision) s.emitRunScoped(ctx, EventAgentDone, &state, turnOutput.assistant) - s.triggerMemoExtraction(state.session.ID, state.session.Messages, state.rememberedThisRun) + s.triggerMemoExtraction(state.session.ID, runBoundaryMessagesForMemo(&state), state.rememberedThisRun) return nil case acceptance.AcceptanceContinue: state.lastAcceptanceBlockedReason = strings.TrimSpace(acceptanceDecision.CompletionBlockedReason) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index f1db6601..a8f72459 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -145,7 +145,7 @@ type ProviderFactory interface { // MemoExtractor 定义 runtime 层调用记忆提取的最小能力。 // 通过接口注入避免 runtime 直接依赖 memo 子系统实现细节。 type MemoExtractor interface { - // Schedule 从消息中安排一次后台记忆提取,失败由实现方自行处理。 + // Schedule 从当前 run 边界内的消息安排一次后台记忆提取,失败由实现方自行处理。 Schedule(sessionID string, messages []providertypes.Message) } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 9a38cc91..5e1ca56d 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -1180,6 +1180,52 @@ func TestServiceRunSchedulesMemoExtractionAfterFinalReply(t *testing.T) { } } +func TestServiceRunSchedulesMemoExtractionFromCurrentRunBoundary(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + store := newMemoryStore() + session := agentsession.New("existing") + session.ID = "session-existing-memo" + session.Messages = []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("old user")}}, + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("old assistant")}}, + } + store.sessions[session.ID] = cloneSession(session) + + scripted := &scriptedProvider{ + streams: [][]providertypes.StreamEvent{ + { + providertypes.NewTextDeltaStreamEvent("new final"), + providertypes.NewMessageDoneStreamEvent("stop", nil), + }, + }, + } + service := NewWithFactory(manager, tools.NewRegistry(), store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + memoExtractor := &stubScheduledMemoExtractor{} + service.SetMemoExtractor(memoExtractor) + + err := service.Run(context.Background(), UserInput{ + SessionID: session.ID, + RunID: "run-memo-boundary", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("new user")}, + }) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if len(memoExtractor.calls) != 1 { + t.Fatalf("memo schedule calls = %d, want 1", len(memoExtractor.calls)) + } + messages := memoExtractor.calls[0].messages + if len(messages) != 2 { + t.Fatalf("scheduled messages = %#v, want current run user+assistant only", messages) + } + if renderPartsForVerification(messages[0].Parts) != "new user" || + renderPartsForVerification(messages[1].Parts) != "new final" { + t.Fatalf("scheduled messages crossed run boundary: %#v", messages) + } +} + func TestServiceRunSkipsAutoMemoExtractionAfterRememberTool(t *testing.T) { t.Parallel() @@ -5555,6 +5601,8 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) { } service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, builder) + memoExtractor := &stubScheduledMemoExtractor{} + service.SetMemoExtractor(memoExtractor) service.compactRunner = &stubCompactRunner{ result: contextcompact.Result{ Messages: []providertypes.Message{ @@ -5617,6 +5665,16 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) { if renderPartsForTest(saved.Messages[2].Parts) != "recovered" { t.Fatalf("expected final assistant reply %q, got %q", "recovered", renderPartsForTest(saved.Messages[2].Parts)) } + if len(memoExtractor.calls) != 1 { + t.Fatalf("memo schedule calls = %d, want 1", len(memoExtractor.calls)) + } + memoMessages := memoExtractor.calls[0].messages + if len(memoMessages) != 2 { + t.Fatalf("memo messages = %+v, want current run user+assistant only after compact", memoMessages) + } + if renderPartsForTest(memoMessages[0].Parts) != "continue" || renderPartsForTest(memoMessages[1].Parts) != "recovered" { + t.Fatalf("memo messages crossed compact boundary: %+v", memoMessages) + } events := collectRuntimeEvents(service.Events()) assertEventSequence(t, events, []EventType{ diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go index 26d3d9bb..9f00db6f 100644 --- a/internal/runtime/session_mutation.go +++ b/internal/runtime/session_mutation.go @@ -29,6 +29,7 @@ func (s *Service) appendUserMessageAndSave(ctx context.Context, state *runState, }); err != nil { return err } + appendMemoRunMessage(state, message) s.emitRunScoped(ctx, EventUserMessage, state, message) return nil } @@ -87,7 +88,7 @@ func (s *Service) appendAssistantMessageOnlyAndSave( } state.session.Messages = append(state.session.Messages, assistant) state.touchSession() - return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ + if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ SessionID: state.session.ID, Messages: []providertypes.Message{assistant}, UpdatedAt: state.session.UpdatedAt, @@ -95,7 +96,11 @@ func (s *Service) appendAssistantMessageOnlyAndSave( Model: state.session.Model, Workdir: state.session.Workdir, HasUnknownUsage: state.session.HasUnknownUsage, - }) + }); err != nil { + return err + } + appendMemoRunMessage(state, assistant) + return nil } // appendSystemMessageAndSave 将系统提醒消息追加到会话并持久化。 @@ -111,7 +116,7 @@ func (s *Service) appendSystemMessageAndSave(ctx context.Context, state *runStat } state.session.Messages = append(state.session.Messages, message) state.touchSession() - return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ + if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ SessionID: state.session.ID, Messages: []providertypes.Message{message}, UpdatedAt: state.session.UpdatedAt, @@ -119,7 +124,11 @@ func (s *Service) appendSystemMessageAndSave(ctx context.Context, state *runStat Model: state.session.Model, Workdir: state.session.Workdir, HasUnknownUsage: state.session.HasUnknownUsage, - }) + }); err != nil { + return err + } + appendMemoRunMessage(state, message) + return nil } // appendToolMessageAndSave 将工具原始结果写回会话,持久化时仅追加一条 tool message。 @@ -143,7 +152,11 @@ func (s *Service) appendToolMessageAndSave( HasUnknownUsage: state.session.HasUnknownUsage, } state.mu.Unlock() - return s.sessionStore.AppendMessages(ctx, input) + if err := s.sessionStore.AppendMessages(ctx, input); err != nil { + return err + } + appendMemoRunMessage(state, toolMessage) + return nil } // normalizeToolMessageForPersistence 负责在写入会话前收敛工具结果,避免成功结果落成完全空语义消息。 diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 2440bddf..277d9c6e 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -4,6 +4,7 @@ import ( "sync" "time" + providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/controlplane" "neo-code/internal/runtime/decider" runtimefacts "neo-code/internal/runtime/facts" @@ -20,6 +21,7 @@ type runState struct { effectiveWorkdir string compactCount int reactiveCompactAttempts int + memoRunMessages []providertypes.Message rememberedThisRun bool planningEnabled bool taskID string From 1b19d3047b7b9131229eb73bb768e4c30468def9 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sun, 10 May 2026 10:47:56 +0800 Subject: [PATCH 2/5] =?UTF-8?q?fix:web=E7=AB=AFslash=E6=8C=87=E4=BB=A4?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/chat/ChatInput.test.tsx | 69 ++++++++++++++++++++++ web/src/components/chat/ChatInput.tsx | 25 ++++---- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index 8c610e7b..bf39b1b2 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -25,6 +25,12 @@ vi.mock('./ModelSelector', () => ({ default: () =>
, })) +async function submitSlashCommand(command: string) { + const textarea = screen.getByRole('textbox') as HTMLTextAreaElement + fireEvent.change(textarea, { target: { value: command } }) + fireEvent.keyDown(textarea, { key: 'Enter' }) +} + describe('ChatInput', () => { beforeEach(() => { vi.clearAllMocks() @@ -134,4 +140,67 @@ describe('ChatInput', () => { expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) + it('executes /memo without session id and shows payload.Content', async () => { + mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({ + payload: { + Content: 'User Memo:\n- [user] coding preference', + }, + }) + render() + + await submitSlashCommand('/memo') + + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).toHaveBeenCalledWith('', '', 'memo_list', {}) + }) + await waitFor(() => { + expect(useChatStore.getState().messages.some((msg) => msg.type === 'system' && msg.content.includes('coding preference'))).toBe(true) + }) + }) + + it('uses fallback text when memo payload has no content field', async () => { + mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({ payload: {} }) + render() + + await submitSlashCommand('/memo') + + await waitFor(() => { + expect(useChatStore.getState().messages.some((msg) => msg.type === 'system' && msg.content === 'Memo query complete')).toBe(true) + }) + }) + + it('executes /remember and /forget without session id', async () => { + mockGatewayAPI.executeSystemTool + .mockResolvedValueOnce({ payload: { Content: 'Memory saved: [user] keep tests strict' } }) + .mockResolvedValueOnce({ payload: { Content: 'Removed 1 memo(s) matching \"strict\".' } }) + render() + + await submitSlashCommand('/remember keep tests strict') + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).toHaveBeenNthCalledWith(1, '', '', 'memo_remember', { + type: 'user', + title: 'keep tests strict', + content: 'keep tests strict', + }) + }) + + await submitSlashCommand('/forget strict') + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).toHaveBeenNthCalledWith(2, '', '', 'memo_remove', { + keyword: 'strict', + scope: 'all', + }) + }) + }) + + it('keeps argument validation for /remember and /forget', async () => { + render() + + await submitSlashCommand('/remember') + await submitSlashCommand('/forget') + + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).not.toHaveBeenCalled() + }) + }) }) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index ff4258e4..d5a4bb2e 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -57,6 +57,13 @@ function buildSlashHelpText(commands: AnySlashCommand[]): string { return ['可用命令:', ...lines].join('\n') } +/** 统一提取系统工具返回文本,兼容 payload.content 与 payload.Content。 */ +function extractSystemToolContent(result: unknown, fallback: string): string { + const payload = (result as { payload?: { content?: string; Content?: string } } | null)?.payload + const content = payload?.content ?? payload?.Content + return content || fallback +} + export default function ChatInput() { const gatewayAPI = useGatewayAPI() const text = useComposerStore((state) => state.composerText) @@ -149,13 +156,9 @@ export default function ChatInput() { return true } case '/memo': { - if (!isValidSessionId(currentSessionId)) { - useUIStore.getState().showToast('Send a message first to start a session', 'error') - return true - } try { const result = await api.executeSystemTool(currentSessionId, '', 'memo_list', {}) - addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo query complete') + addSystemMessage(extractSystemToolContent(result, 'Memo query complete')) } catch (err) { console.error('Memo list failed:', err) useUIStore.getState().showToast('Failed to query memo', 'error') @@ -167,17 +170,13 @@ export default function ChatInput() { useUIStore.getState().showToast('Usage: /remember ', 'error') return true } - if (!isValidSessionId(currentSessionId)) { - useUIStore.getState().showToast('Send a message first to start a session', 'error') - return true - } try { const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', { type: 'user', title: argument, content: argument, }) - addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo saved') + addSystemMessage(extractSystemToolContent(result, 'Memo saved')) } catch (err) { console.error('Remember failed:', err) useUIStore.getState().showToast('Failed to save memo', 'error') @@ -189,16 +188,12 @@ export default function ChatInput() { useUIStore.getState().showToast('Usage: /forget ', 'error') return true } - if (!isValidSessionId(currentSessionId)) { - useUIStore.getState().showToast('Send a message first to start a session', 'error') - return true - } try { const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', { keyword: argument, scope: 'all', }) - addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo deleted') + addSystemMessage(extractSystemToolContent(result, 'Memo deleted')) } catch (err) { console.error('Forget failed:', err) useUIStore.getState().showToast('Failed to delete memo', 'error') From 5cde1d19f005baa0d52abe43ae30660260b466ac Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sun, 10 May 2026 10:36:56 +0800 Subject: [PATCH 3/5] =?UTF-8?q?fix(memo):memo=E5=88=9D=E6=AD=A5=E4=BF=AE?= =?UTF-8?q?=E7=90=86,run=E6=88=AA=E5=8F=96=EF=BC=8Cllm=E6=8F=90=E5=8F=96?= =?UTF-8?q?=E6=97=B6=E5=8E=BB=E9=87=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/bootstrap_test.go | 42 +++++++ internal/context/projection.go | 36 ++++++ internal/context/projection_test.go | 44 ++++++++ internal/memo/auto_extractor.go | 64 +++++++++++ internal/memo/auto_extractor_test.go | 144 ++++++++++++++++++++++++ internal/memo/llm_extractor.go | 123 ++++++++++++++++++--- internal/memo/llm_extractor_test.go | 61 +++++++++-- internal/memo/service.go | 158 +++++++++++++++++++++++++++ internal/memo/types.go | 38 +++++++ internal/runtime/memo.go | 24 ++++ internal/runtime/run.go | 2 +- internal/runtime/runtime.go | 2 +- internal/runtime/runtime_test.go | 58 ++++++++++ internal/runtime/session_mutation.go | 23 +++- internal/runtime/state.go | 2 + 15 files changed, 793 insertions(+), 28 deletions(-) diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index 3760dded..917823cb 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -1675,6 +1675,48 @@ func TestNewMemoExtractorAdapterBuildsProviderSafeMemoWindow(t *testing.T) { } } +func TestNewMemoExtractorAdapterUsesFullRunMemoWindow(t *testing.T) { + t.Setenv(config.OpenAIDefaultAPIKeyEnv, "token") + cfg := config.StaticDefaults().Clone() + cfg.SelectedProvider = config.OpenAIName + cfg.Memo.ExtractRecentMessages = 3 + manager := config.NewManager(config.NewLoader("", &cfg)) + + providerStub := &stubMemoProvider{ + generate: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + if len(req.Messages) != 12 { + t.Fatalf("unexpected memo window length %d, want full run: %+v", len(req.Messages), req.Messages) + } + events <- providertypes.NewTextDeltaStreamEvent(`[]`) + events <- providertypes.NewMessageDoneStreamEvent("stop", nil) + return nil + }, + } + factory := &stubMemoProviderFactory{provider: providerStub} + scheduler := &stubMemoExtractorScheduler{} + extractor := newMemoExtractorAdapter(factory, manager, scheduler) + + inputMessages := make([]providertypes.Message, 0, 12) + for index := 0; index < 12; index++ { + inputMessages = append(inputMessages, providertypes.Message{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart(fmt.Sprintf("message-%02d", index))}, + }) + } + extractor.Schedule("session-1", inputMessages) + if !scheduler.called || scheduler.extractor == nil { + t.Fatalf("expected scheduler to receive extractor") + } + + _, err := scheduler.extractor.Extract(context.Background(), inputMessages) + if err != nil { + t.Fatalf("extractor.Extract() error = %v", err) + } + if !factory.called { + t.Fatalf("expected provider factory Build to be called") + } +} + func TestNewMemoExtractorAdapterKeepsScheduledConfigSnapshot(t *testing.T) { t.Setenv(config.OpenAIDefaultAPIKeyEnv, "openai-token") t.Setenv(config.QiniuDefaultAPIKeyEnv, "qiniu-token") diff --git a/internal/context/projection.go b/internal/context/projection.go index ac7da4b3..298f7e8c 100644 --- a/internal/context/projection.go +++ b/internal/context/projection.go @@ -79,6 +79,42 @@ func BuildRecentMessagesForModel(messages []providertypes.Message, limit int) [] return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected))) } +// BuildMemoExtractionMessagesForModel 构造完整 run 的 provider-safe 记忆提取上下文。 +func BuildMemoExtractionMessagesForModel(messages []providertypes.Message) []providertypes.Message { + if len(messages) == 0 { + return nil + } + + keep := make([]bool, len(messages)) + for index := 0; index < len(messages); index++ { + message := messages[index] + if message.Role == providertypes.RoleTool { + continue + } + + if message.Role == providertypes.RoleAssistant && len(message.ToolCalls) > 0 { + for _, spanIndex := range matchedToolCallSpan(messages, index) { + keep[spanIndex] = true + } + continue + } + + keep[index] = true + } + + selected := make([]providertypes.Message, 0, len(messages)) + for index, message := range messages { + if keep[index] { + selected = append(selected, message) + } + } + if len(selected) == 0 { + return nil + } + + return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected))) +} + // matchedToolCallSpan 返回 assistant tool call 与其完整 tool 响应组成的合法窗口下标集合。 func matchedToolCallSpan(messages []providertypes.Message, assistantIndex int) []int { if assistantIndex < 0 || assistantIndex >= len(messages) { diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go index 372e3c3f..3b0c932a 100644 --- a/internal/context/projection_test.go +++ b/internal/context/projection_test.go @@ -140,6 +140,50 @@ func TestBuildRecentMessagesForModelKeepsOnlyRecentValidAnchors(t *testing.T) { } } +func TestBuildMemoExtractionMessagesForModelKeepsFullRunSafeSpans(t *testing.T) { + t.Parallel() + + messages := []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}}, + {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan")}}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("README body")}, + ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, + }, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-missing", Name: "bash", Arguments: `{}`}, + }, + }, + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("last")}}, + } + + projected := BuildMemoExtractionMessagesForModel(messages) + if len(projected) != 4 { + t.Fatalf("len(projected) = %d, want 4: %+v", len(projected), projected) + } + if renderDisplayParts(projected[0].Parts) != "first" || renderDisplayParts(projected[3].Parts) != "last" { + t.Fatalf("expected full run user messages to remain, got %+v", projected) + } + if projected[1].Role != providertypes.RoleAssistant || len(projected[1].ToolCalls) != 1 { + t.Fatalf("expected complete assistant tool span, got %+v", projected[1]) + } + if projected[2].Role != providertypes.RoleTool || + !strings.Contains(renderDisplayParts(projected[2].Parts), "tool result") || + projected[2].ToolMetadata != nil { + t.Fatalf("expected projected tool result, got %+v", projected[2]) + } +} + func TestBuildRecentMessagesForModelRespectsAbsoluteMessageBudget(t *testing.T) { t.Parallel() diff --git a/internal/memo/auto_extractor.go b/internal/memo/auto_extractor.go index aaadfec9..a1ccb43b 100644 --- a/internal/memo/auto_extractor.go +++ b/internal/memo/auto_extractor.go @@ -225,6 +225,24 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider ctx, cancel := context.WithTimeout(context.Background(), a.extractTimeout) defer cancel() + if decisionExtractor, ok := extractor.(DecisionExtractor); ok { + existing, err := a.svc.autoExtractionCandidates(ctx) + if err != nil { + a.logError("memo: auto extract load candidates failed: %v", err) + return false + } + decisions, err := decisionExtractor.ExtractDecisions(ctx, messages, existing) + if err != nil { + if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) { + a.logError("memo: auto extract skipped (protocol_mismatch): %v", err) + return true + } + a.logError("memo: auto extract failed: %v", err) + return false + } + return a.applyExtractionDecisions(ctx, decisions) + } + entries, err := extractor.Extract(ctx, messages) if err != nil { if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) { @@ -265,6 +283,52 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider return succeeded } +// applyExtractionDecisions 应用模型返回的 create/update/skip 决策,并保留本地精确去重兜底。 +func (a *AutoExtractor) applyExtractionDecisions(ctx context.Context, decisions []ExtractionDecision) bool { + if len(decisions) == 0 { + return true + } + + seenCreates := make(map[string]struct{}, len(decisions)) + succeeded := true + for _, decision := range decisions { + switch decision.Action { + case ExtractionActionCreate: + entry := decision.Entry + entry.Source = SourceAutoExtract + key := autoExtractDedupKey(entry) + if key == "" { + continue + } + if _, exists := seenCreates[key]; exists { + continue + } + added, err := a.svc.addAutoExtractIfAbsent(ctx, entry) + if err != nil { + a.logError("memo: auto extract add failed: %v", err) + succeeded = false + continue + } + seenCreates[key] = struct{}{} + if !added { + continue + } + case ExtractionActionUpdate: + entry := decision.Entry + entry.Source = SourceAutoExtract + _, err := a.svc.updateAutoExtractIfAllowed(ctx, decision.Ref, entry) + if err != nil { + a.logError("memo: auto extract update failed: %v", err) + succeeded = false + continue + } + case ExtractionActionSkip: + continue + } + } + return succeeded +} + // autoExtractDedupKey 生成自动提取条目的精确去重键。 func autoExtractDedupKey(entry Entry) string { title := NormalizeTitle(entry.Title) diff --git a/internal/memo/auto_extractor_test.go b/internal/memo/auto_extractor_test.go index 2f188d07..6c4e5835 100644 --- a/internal/memo/auto_extractor_test.go +++ b/internal/memo/auto_extractor_test.go @@ -38,6 +38,33 @@ func (s *stubMemoExtractor) Calls() int { return s.callCount } +type stubDecisionMemoExtractor struct { + mu sync.Mutex + callCount int + candidates []ExtractionCandidate + decisions []ExtractionDecision + err error +} + +func (s *stubDecisionMemoExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { + return nil, nil +} + +func (s *stubDecisionMemoExtractor) ExtractDecisions( + ctx context.Context, + messages []providertypes.Message, + existing []ExtractionCandidate, +) ([]ExtractionDecision, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.callCount++ + s.candidates = append([]ExtractionCandidate(nil), existing...) + if s.err != nil { + return nil, s.err + } + return append([]ExtractionDecision(nil), s.decisions...), nil +} + func newAutoExtractorTestService(t *testing.T) *Service { t.Helper() store := NewFileStore(t.TempDir(), t.TempDir()) @@ -214,6 +241,123 @@ func TestAutoExtractorSuppressesExactDuplicates(t *testing.T) { }) } +func TestAutoExtractorAppliesSemanticUpdateOnlyForAutoExtractedMemory(t *testing.T) { + svc := newAutoExtractorTestService(t) + if err := svc.Add(context.Background(), Entry{ + Type: TypeFeedback, + Title: "测试策略", + Content: "用户要求修改后先跑测试。", + Source: SourceAutoExtract, + }); err != nil { + t.Fatalf("seed auto Add() error = %v", err) + } + if err := svc.Add(context.Background(), Entry{ + Type: TypeUser, + Title: "语言偏好", + Content: "用户手动保存偏好中文回复。", + Source: SourceUserManual, + }); err != nil { + t.Fatalf("seed manual Add() error = %v", err) + } + + candidates, err := svc.autoExtractionCandidates(context.Background()) + if err != nil { + t.Fatalf("autoExtractionCandidates() error = %v", err) + } + var autoRef, manualRef string + for _, candidate := range candidates { + switch candidate.Source { + case SourceAutoExtract: + autoRef = candidate.Ref + case SourceUserManual: + manualRef = candidate.Ref + } + } + if autoRef == "" || manualRef == "" { + t.Fatalf("expected refs, auto=%q manual=%q candidates=%+v", autoRef, manualRef, candidates) + } + + extractor := &stubDecisionMemoExtractor{ + decisions: []ExtractionDecision{ + { + Action: ExtractionActionUpdate, + Ref: autoRef, + Entry: Entry{ + Title: "测试策略", + Content: "用户要求修改后先跑相关测试。", + Keywords: []string{"test"}, + }, + }, + { + Action: ExtractionActionUpdate, + Ref: manualRef, + Entry: Entry{ + Title: "语言偏好", + Content: "不应覆盖手动记忆。", + }, + }, + }, + } + auto := NewAutoExtractor(extractor, svc, time.Second) + auto.debounce = 5 * time.Millisecond + auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) + + auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("更新测试策略")}}}) + waitFor(t, time.Second, func() bool { + extractor.mu.Lock() + defer extractor.mu.Unlock() + return extractor.callCount == 1 + }) + + recall, err := svc.Recall(context.Background(), "测试策略", ScopeProject) + if err != nil { + t.Fatalf("Recall(auto) error = %v", err) + } + if len(recall) != 1 || !strings.Contains(recall[0].Content, "相关测试") { + t.Fatalf("expected auto memory to be updated, got %+v", recall) + } + manualRecall, err := svc.Recall(context.Background(), "语言偏好", ScopeUser) + if err != nil { + t.Fatalf("Recall(manual) error = %v", err) + } + if len(manualRecall) != 1 || strings.Contains(manualRecall[0].Content, "不应覆盖") { + t.Fatalf("manual memory should not be overwritten, got %+v", manualRecall) + } + if len(extractor.candidates) != 2 { + t.Fatalf("expected existing candidates to be provided, got %+v", extractor.candidates) + } +} + +func TestAutoExtractorSemanticCreateStillUsesExactDedup(t *testing.T) { + svc := newAutoExtractorTestService(t) + if err := svc.Add(context.Background(), Entry{ + Type: TypeUser, + Title: "中文回复", + Content: "用户偏好中文回复。", + Source: SourceAutoExtract, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + extractor := &stubDecisionMemoExtractor{ + decisions: []ExtractionDecision{ + {Action: ExtractionActionCreate, Entry: Entry{Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"}}, + {Action: ExtractionActionCreate, Entry: Entry{Type: TypeProject, Title: "新事实", Content: "项目需要语义去重。"}}, + }, + } + auto := NewAutoExtractor(extractor, svc, time.Second) + auto.debounce = 5 * time.Millisecond + auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) + + auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("dedupe")}}}) + waitFor(t, time.Second, func() bool { + entries, err := svc.List(context.Background(), ScopeAll) + return err == nil && len(entries) == 2 + }) +} + func TestAutoExtractorUsesTimeoutContext(t *testing.T) { svc := newAutoExtractorTestService(t) extractor := &stubMemoExtractor{ diff --git a/internal/memo/llm_extractor.go b/internal/memo/llm_extractor.go index 9d039f44..621446de 100644 --- a/internal/memo/llm_extractor.go +++ b/internal/memo/llm_extractor.go @@ -27,6 +27,8 @@ type LLMExtractor struct { } type extractedEntry struct { + Action string `json:"action"` + Ref string `json:"ref"` Type string `json:"type"` Title string `json:"title"` Content string `json:"content"` @@ -45,8 +47,27 @@ func NewLLMExtractor(generator TextGenerator, recentMessageLimit int) *LLMExtrac } } -// Extract 从最近对话中提取可跨会话持久化的记忆条目。 +// Extract 从当前 run 对话中提取可跨会话持久化的新增记忆条目。 func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { + decisions, err := e.ExtractDecisions(ctx, messages, nil) + if err != nil { + return nil, err + } + entries := make([]Entry, 0, len(decisions)) + for _, decision := range decisions { + if decision.Action == ExtractionActionCreate { + entries = append(entries, decision.Entry) + } + } + return entries, nil +} + +// ExtractDecisions 从当前 run 对话中提取记忆,并结合既有记忆输出新增、合并或跳过决策。 +func (e *LLMExtractor) ExtractDecisions( + ctx context.Context, + messages []providertypes.Message, + existing []ExtractionCandidate, +) ([]ExtractionDecision, error) { if err := ctx.Err(); err != nil { return nil, err } @@ -54,12 +75,12 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes return nil, errors.New("memo: text generator is nil") } - recent := agentcontext.BuildRecentMessagesForModel(messages, e.recentMessageLimit) - if len(recent) == 0 || !containsUserMessage(recent) { + runMessages := agentcontext.BuildMemoExtractionMessagesForModel(messages) + if len(runMessages) == 0 || !containsUserMessage(runMessages) { return nil, nil } - response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now()), recent) + response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now(), existing), runMessages) if err != nil { return nil, err } @@ -77,23 +98,29 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes return nil, fmt.Errorf("memo: parse extraction response: %w", err) } - entries := make([]Entry, 0, len(extracted)) + decisions := make([]ExtractionDecision, 0, len(extracted)) for _, item := range extracted { - entry, ok := toMemoEntry(item) + decision, ok := toExtractionDecision(item) if !ok { continue } - entries = append(entries, entry) + decisions = append(decisions, decision) } - return entries, nil + return decisions, nil } // buildExtractionPrompt 构造记忆提取专用的 system prompt。 -func buildExtractionPrompt(now time.Time) string { +func buildExtractionPrompt(now time.Time, existing []ExtractionCandidate) string { currentDate := now.In(time.Local).Format("2006-01-02") + existingJSON := "[]" + if len(existing) > 0 { + if data, err := json.Marshal(existing); err == nil { + existingJSON = string(data) + } + } return strings.TrimSpace(fmt.Sprintf(` 你是一个记忆提取助手(memory extraction assistant)。 -分析最近对话中值得跨会话持久记住的信息,并返回严格 JSON 数组。 +分析当前 run 对话中值得跨会话持久记住的信息,并结合既有记忆完成语义去重,返回严格 JSON 数组。 当前本地日期:%s 如果对话中出现“今天、明天、下周二”等相对日期,必须先转换为绝对日期再写入 content。 @@ -109,11 +136,18 @@ func buildExtractionPrompt(now time.Time) string { 2. 不要提取通用编程知识、代码结构、文件路径、Git 历史。 3. 每条记忆必须具体、可操作。 4. 没有值得记住的信息时,返回 []。 -5. 输出必须是 JSON 数组,不要输出任何额外解释。 +5. 如果新信息与既有记忆语义相同或只是轻微改写,输出 action="skip"。 +6. 如果新信息能补充或修正既有 source="extractor_auto" 的记忆,输出 action="update" 并填写目标 ref。 +7. 不允许 update source 不是 "extractor_auto" 的既有记忆;这类相近内容只能 skip。 +8. 如果是全新的可持久化信息,输出 action="create"。 +9. 输出必须是 JSON 数组,不要输出任何额外解释。 + +既有记忆候选(JSON): +%s 输出格式: -[{"type":"user","title":"...","content":"...","keywords":["..."]}] -`, currentDate)) +[{"action":"create","type":"user","title":"...","content":"...","keywords":["..."]},{"action":"update","ref":"project:p.md","title":"...","content":"...","keywords":["..."]},{"action":"skip","ref":"user:u.md"}] +`, currentDate, existingJSON)) } // containsUserMessage 检查待提取消息中是否包含用户输入。 @@ -148,6 +182,69 @@ func toMemoEntry(item extractedEntry) (Entry, bool) { }, true } +// toExtractionDecision 将 LLM 输出收敛为自动提取持久化决策。 +func toExtractionDecision(item extractedEntry) (ExtractionDecision, bool) { + action := parseExtractionAction(item.Action) + if action == "" { + return ExtractionDecision{}, false + } + if action == ExtractionActionSkip { + return ExtractionDecision{ + Action: action, + Ref: strings.TrimSpace(item.Ref), + }, true + } + + if action == ExtractionActionUpdate { + ref := strings.TrimSpace(item.Ref) + if ref == "" { + return ExtractionDecision{}, false + } + entry, ok := toMemoUpdateEntry(item) + if !ok { + return ExtractionDecision{}, false + } + return ExtractionDecision{Action: action, Ref: ref, Entry: entry}, true + } + + entry, ok := toMemoEntry(item) + if !ok { + return ExtractionDecision{}, false + } + return ExtractionDecision{Action: action, Entry: entry}, true +} + +// toMemoUpdateEntry 将 update 决策中的可变字段收敛为 Entry 片段。 +func toMemoUpdateEntry(item extractedEntry) (Entry, bool) { + title := NormalizeTitle(item.Title) + content := strings.TrimSpace(item.Content) + if title == "" || content == "" { + return Entry{}, false + } + return Entry{ + Title: title, + Content: content, + Keywords: normalizeKeywords(item.Keywords), + Source: SourceAutoExtract, + }, true +} + +// parseExtractionAction 解析模型决策动作,并兼容旧格式中缺省 action 的 create 输出。 +func parseExtractionAction(action string) ExtractionAction { + switch ExtractionAction(strings.ToLower(strings.TrimSpace(action))) { + case "": + return ExtractionActionCreate + case ExtractionActionCreate: + return ExtractionActionCreate + case ExtractionActionUpdate: + return ExtractionActionUpdate + case ExtractionActionSkip: + return ExtractionActionSkip + default: + return "" + } +} + // normalizeKeywords 规范化关键词列表,移除空值和重复值。 func normalizeKeywords(keywords []string) []string { if len(keywords) == 0 { diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go index 3ccf9ee9..06813229 100644 --- a/internal/memo/llm_extractor_test.go +++ b/internal/memo/llm_extractor_test.go @@ -196,8 +196,8 @@ func TestLLMExtractorExtractCancelledContext(t *testing.T) { } } -// TestLLMExtractorExtractUsesRecentNonToolMessages 验证只取最近 10 条非 tool 消息。 -func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) { +// TestLLMExtractorExtractUsesFullRunMessages 验证提取器使用完整 run 消息而非固定 recent window。 +func TestLLMExtractorExtractUsesFullRunMessages(t *testing.T) { generator := &stubTextGenerator{response: `[]`} extractor := NewLLMExtractor(generator, 10) @@ -219,19 +219,64 @@ func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) { if err != nil { t.Fatalf("Extract() error = %v", err) } - if len(generator.messages) != 10 { - t.Fatalf("len(generator.messages) = %d, want 10", len(generator.messages)) + if len(generator.messages) != 12 { + t.Fatalf("len(generator.messages) = %d, want 12", len(generator.messages)) } for _, message := range generator.messages { if message.Role == providertypes.RoleTool { t.Fatalf("unexpected tool message in extraction context: %#v", message) } } - if renderMemoParts(generator.messages[0].Parts) != "user-c" || - renderMemoParts(generator.messages[9].Parts) != "user-l" { - t.Fatalf("unexpected recent window: first=%q last=%q", + if renderMemoParts(generator.messages[0].Parts) != "user-a" || + renderMemoParts(generator.messages[11].Parts) != "user-l" { + t.Fatalf("unexpected run window: first=%q last=%q", renderMemoParts(generator.messages[0].Parts), - renderMemoParts(generator.messages[9].Parts)) + renderMemoParts(generator.messages[11].Parts)) + } +} + +func TestLLMExtractorExtractDecisionsIncludesExistingCandidates(t *testing.T) { + generator := &stubTextGenerator{ + response: `[{"action":"skip","ref":"user:u.md"},{"action":"update","ref":"project:p.md","title":"测试策略","content":"用户要求修改后先跑相关测试。","keywords":["test"]}]`, + } + extractor := NewLLMExtractor(generator, 10) + + decisions, err := extractor.ExtractDecisions( + context.Background(), + []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑相关测试。")}}, + }, + []ExtractionCandidate{ + { + Ref: "project:p.md", + Scope: ScopeProject, + Type: TypeFeedback, + Source: SourceAutoExtract, + Title: "测试策略", + Content: "用户要求修改后先跑测试。", + }, + }, + ) + if err != nil { + t.Fatalf("ExtractDecisions() error = %v", err) + } + if len(decisions) != 2 { + t.Fatalf("len(decisions) = %d, want 2", len(decisions)) + } + if decisions[0].Action != ExtractionActionSkip || decisions[0].Ref != "user:u.md" { + t.Fatalf("unexpected skip decision: %+v", decisions[0]) + } + if decisions[1].Action != ExtractionActionUpdate || decisions[1].Ref != "project:p.md" { + t.Fatalf("unexpected update decision: %+v", decisions[1]) + } + if decisions[1].Entry.Type != "" { + t.Fatalf("update decision should not require type, got %+v", decisions[1].Entry) + } + if !strings.Contains(generator.prompt, `"ref":"project:p.md"`) { + t.Fatalf("prompt should include existing memory candidates, got %q", generator.prompt) + } + if !strings.Contains(generator.prompt, `action="update"`) || !strings.Contains(generator.prompt, `source="extractor_auto"`) { + t.Fatalf("prompt should describe semantic dedupe protocol, got %q", generator.prompt) } } diff --git a/internal/memo/service.go b/internal/memo/service.go index c7415c50..42551f26 100644 --- a/internal/memo/service.go +++ b/internal/memo/service.go @@ -68,6 +68,130 @@ func (s *Service) addAutoExtractIfAbsent(ctx context.Context, entry Entry) (bool return true, nil } +// autoExtractionCandidates 加载既有记忆快照,供模型在提取时做语义去重判断。 +func (s *Service) autoExtractionCandidates(ctx context.Context) ([]ExtractionCandidate, error) { + if s == nil { + return nil, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + candidates := make([]ExtractionCandidate, 0) + for _, scope := range supportedStorageScopes() { + index, err := s.loadIndexLocked(ctx, scope) + if err != nil { + return nil, err + } + for _, entry := range index.Entries { + topicFile := strings.TrimSpace(entry.TopicFile) + if topicFile == "" { + continue + } + topicContent, err := s.store.LoadTopic(ctx, scope, topicFile) + if err != nil { + continue + } + source, content := parseTopicSourceAndContent(topicContent) + candidates = append(candidates, ExtractionCandidate{ + Ref: scopedTopicKey(scope, topicFile), + Scope: scope, + Type: entry.Type, + Source: source, + Title: entry.Title, + Content: content, + }) + } + } + return candidates, nil +} + +// updateAutoExtractIfAllowed 按 ref 更新既有自动提取记忆,显式记忆不会被后台流程覆盖。 +func (s *Service) updateAutoExtractIfAllowed(ctx context.Context, ref string, next Entry) (bool, error) { + next.Title = NormalizeTitle(next.Title) + next.Content = strings.TrimSpace(next.Content) + next.Keywords = normalizeKeywords(next.Keywords) + if next.Title == "" { + return false, fmt.Errorf("memo: title is empty") + } + if next.Content == "" { + return false, fmt.Errorf("memo: content is empty") + } + if err := s.ensureAutoExtractIndex(ctx); err != nil { + return false, err + } + + scope, topicFile, ok := parseScopedTopicKey(ref) + if !ok { + return false, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + index, err := s.loadIndexLocked(ctx, scope) + if err != nil { + return false, err + } + working := cloneIndex(index) + + targetIndex := -1 + var current Entry + for idx, existing := range working.Entries { + if strings.TrimSpace(existing.TopicFile) == topicFile { + targetIndex = idx + current = existing + break + } + } + if targetIndex < 0 { + return false, nil + } + + topicContent, err := s.store.LoadTopic(ctx, scope, topicFile) + if err != nil { + return false, nil + } + source, _ := parseTopicSourceAndContent(topicContent) + if source != SourceAutoExtract { + return false, nil + } + + current.Title = next.Title + current.Content = next.Content + current.Keywords = next.Keywords + current.Source = SourceAutoExtract + current.TopicFile = topicFile + working.Entries[targetIndex] = current + working.UpdatedAt = time.Now() + + removedEntries := trimIndexEntries(working, s.config.MaxEntries, s.config.MaxIndexBytes) + if err := s.store.SaveTopic(ctx, scope, topicFile, RenderTopic(¤t)); err != nil { + return false, fmt.Errorf("memo: save topic: %w", err) + } + if err := s.store.SaveIndex(ctx, scope, working); err != nil { + return false, fmt.Errorf("memo: save index: %w", err) + } + for _, removed := range removedEntries { + if removedTopic := strings.TrimSpace(removed.TopicFile); removedTopic != "" { + _ = s.store.DeleteTopic(ctx, scope, removedTopic) + } + } + + if s.autoExtractIndexReady { + s.removeAutoExtractTopicLocked(scope, topicFile) + for _, removed := range removedEntries { + s.removeAutoExtractTopicLocked(scope, removed.TopicFile) + } + if indexContainsTopicFile(working, topicFile) { + s.trackAutoExtractEntryLocked(scope, current) + } + } + + s.invalidateCache() + return true, nil +} + // normalizeKeyword 统一关键词的空格与大小写处理。 func normalizeKeyword(keyword string) string { return strings.ToLower(strings.TrimSpace(keyword)) @@ -557,6 +681,23 @@ func indexContainsEntryID(index *Index, entryID string) bool { return false } +// indexContainsTopicFile 判断索引中是否仍保留指定 topic 文件。 +func indexContainsTopicFile(index *Index, topicFile string) bool { + if index == nil { + return false + } + topicFile = strings.TrimSpace(topicFile) + if topicFile == "" { + return false + } + for _, item := range index.Entries { + if strings.TrimSpace(item.TopicFile) == topicFile { + return true + } + } + return false +} + // scopesForQuery 将查询范围展开为实际存储分层列表。 func scopesForQuery(scope Scope) []Scope { switch NormalizeScope(scope) { @@ -584,3 +725,20 @@ func validateQueryScope(scope Scope) error { func scopedTopicKey(scope Scope, topicFile string) string { return string(scope) + ":" + strings.TrimSpace(topicFile) } + +// parseScopedTopicKey 解析自动提取语义去重协议中的 ref 字段。 +func parseScopedTopicKey(ref string) (Scope, string, bool) { + parts := strings.SplitN(strings.TrimSpace(ref), ":", 2) + if len(parts) != 2 { + return "", "", false + } + scope := Scope(strings.TrimSpace(parts[0])) + if err := validateStorageScope(scope); err != nil { + return "", "", false + } + topicFile := strings.TrimSpace(parts[1]) + if topicFile == "" { + return "", "", false + } + return scope, topicFile, true +} diff --git a/internal/memo/types.go b/internal/memo/types.go index 8572e0e4..0d1a62a0 100644 --- a/internal/memo/types.go +++ b/internal/memo/types.go @@ -89,6 +89,35 @@ type RecalledEntry struct { Content string } +// ExtractionAction 表示自动提取器对单条候选记忆的持久化决策。 +type ExtractionAction string + +const ( + // ExtractionActionCreate 表示新增一条记忆。 + ExtractionActionCreate ExtractionAction = "create" + // ExtractionActionUpdate 表示合并更新一条既有自动提取记忆。 + ExtractionActionUpdate ExtractionAction = "update" + // ExtractionActionSkip 表示跳过重复或不值得沉淀的内容。 + ExtractionActionSkip ExtractionAction = "skip" +) + +// ExtractionCandidate 表示提供给模型做语义去重的既有记忆快照。 +type ExtractionCandidate struct { + Ref string `json:"ref"` + Scope Scope `json:"scope"` + Type Type `json:"type"` + Source string `json:"source"` + Title string `json:"title"` + Content string `json:"content"` +} + +// ExtractionDecision 表示模型针对新旧记忆关系返回的结构化决策。 +type ExtractionDecision struct { + Action ExtractionAction + Ref string + Entry Entry +} + // Store 定义记忆持久化的最小抽象。 type Store interface { LoadIndex(ctx context.Context, scope Scope) (*Index, error) @@ -104,6 +133,15 @@ type Extractor interface { Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) } +// DecisionExtractor 定义带既有记忆快照的语义提取能力。 +type DecisionExtractor interface { + ExtractDecisions( + ctx context.Context, + messages []providertypes.Message, + existing []ExtractionCandidate, + ) ([]ExtractionDecision, error) +} + // TextGenerator 定义调用 LLM 生成文本的最小能力,用于记忆提取。 // 该接口隔离 provider 细节,避免 memo 包直接依赖 provider 基础设施。 type TextGenerator interface { diff --git a/internal/runtime/memo.go b/internal/runtime/memo.go index 477933d7..d2ce6202 100644 --- a/internal/runtime/memo.go +++ b/internal/runtime/memo.go @@ -19,6 +19,30 @@ func (s *Service) triggerMemoExtraction(sessionID string, messages []providertyp s.memoExtractor.Schedule(sessionID, cloneMessages(messages)) } +// runBoundaryMessagesForMemo 返回当前 run 边界内的消息切片,供自动记忆提取使用。 +func runBoundaryMessagesForMemo(state *runState) []providertypes.Message { + if state == nil { + return nil + } + state.mu.Lock() + defer state.mu.Unlock() + return cloneMessages(state.memoRunMessages) +} + +// appendMemoRunMessage 记录当前 run 内已成功写入 transcript 的消息,作为自动记忆提取边界。 +func appendMemoRunMessage(state *runState, message providertypes.Message) { + if state == nil || message.IsEmpty() { + return + } + cloned := cloneMessages([]providertypes.Message{message}) + if len(cloned) == 0 { + return + } + state.mu.Lock() + defer state.mu.Unlock() + state.memoRunMessages = append(state.memoRunMessages, cloned[0]) +} + // isSuccessfulRememberToolCall 判断工具调用是否成功完成显式记忆写入。 func isSuccessfulRememberToolCall(callName string, result tools.ToolResult, execErr error) bool { if execErr != nil || result.IsError { diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 6887ee55..4fb3d22d 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -472,7 +472,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { }) recordAcceptanceTerminal(&state, acceptanceDecision) s.emitRunScoped(ctx, EventAgentDone, &state, turnOutput.assistant) - s.triggerMemoExtraction(state.session.ID, state.session.Messages, state.rememberedThisRun) + s.triggerMemoExtraction(state.session.ID, runBoundaryMessagesForMemo(&state), state.rememberedThisRun) return nil case acceptance.AcceptanceContinue: state.lastAcceptanceBlockedReason = strings.TrimSpace(acceptanceDecision.CompletionBlockedReason) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index f1db6601..a8f72459 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -145,7 +145,7 @@ type ProviderFactory interface { // MemoExtractor 定义 runtime 层调用记忆提取的最小能力。 // 通过接口注入避免 runtime 直接依赖 memo 子系统实现细节。 type MemoExtractor interface { - // Schedule 从消息中安排一次后台记忆提取,失败由实现方自行处理。 + // Schedule 从当前 run 边界内的消息安排一次后台记忆提取,失败由实现方自行处理。 Schedule(sessionID string, messages []providertypes.Message) } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 9a38cc91..5e1ca56d 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -1180,6 +1180,52 @@ func TestServiceRunSchedulesMemoExtractionAfterFinalReply(t *testing.T) { } } +func TestServiceRunSchedulesMemoExtractionFromCurrentRunBoundary(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + store := newMemoryStore() + session := agentsession.New("existing") + session.ID = "session-existing-memo" + session.Messages = []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("old user")}}, + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("old assistant")}}, + } + store.sessions[session.ID] = cloneSession(session) + + scripted := &scriptedProvider{ + streams: [][]providertypes.StreamEvent{ + { + providertypes.NewTextDeltaStreamEvent("new final"), + providertypes.NewMessageDoneStreamEvent("stop", nil), + }, + }, + } + service := NewWithFactory(manager, tools.NewRegistry(), store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + memoExtractor := &stubScheduledMemoExtractor{} + service.SetMemoExtractor(memoExtractor) + + err := service.Run(context.Background(), UserInput{ + SessionID: session.ID, + RunID: "run-memo-boundary", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("new user")}, + }) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if len(memoExtractor.calls) != 1 { + t.Fatalf("memo schedule calls = %d, want 1", len(memoExtractor.calls)) + } + messages := memoExtractor.calls[0].messages + if len(messages) != 2 { + t.Fatalf("scheduled messages = %#v, want current run user+assistant only", messages) + } + if renderPartsForVerification(messages[0].Parts) != "new user" || + renderPartsForVerification(messages[1].Parts) != "new final" { + t.Fatalf("scheduled messages crossed run boundary: %#v", messages) + } +} + func TestServiceRunSkipsAutoMemoExtractionAfterRememberTool(t *testing.T) { t.Parallel() @@ -5555,6 +5601,8 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) { } service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, builder) + memoExtractor := &stubScheduledMemoExtractor{} + service.SetMemoExtractor(memoExtractor) service.compactRunner = &stubCompactRunner{ result: contextcompact.Result{ Messages: []providertypes.Message{ @@ -5617,6 +5665,16 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) { if renderPartsForTest(saved.Messages[2].Parts) != "recovered" { t.Fatalf("expected final assistant reply %q, got %q", "recovered", renderPartsForTest(saved.Messages[2].Parts)) } + if len(memoExtractor.calls) != 1 { + t.Fatalf("memo schedule calls = %d, want 1", len(memoExtractor.calls)) + } + memoMessages := memoExtractor.calls[0].messages + if len(memoMessages) != 2 { + t.Fatalf("memo messages = %+v, want current run user+assistant only after compact", memoMessages) + } + if renderPartsForTest(memoMessages[0].Parts) != "continue" || renderPartsForTest(memoMessages[1].Parts) != "recovered" { + t.Fatalf("memo messages crossed compact boundary: %+v", memoMessages) + } events := collectRuntimeEvents(service.Events()) assertEventSequence(t, events, []EventType{ diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go index 26d3d9bb..9f00db6f 100644 --- a/internal/runtime/session_mutation.go +++ b/internal/runtime/session_mutation.go @@ -29,6 +29,7 @@ func (s *Service) appendUserMessageAndSave(ctx context.Context, state *runState, }); err != nil { return err } + appendMemoRunMessage(state, message) s.emitRunScoped(ctx, EventUserMessage, state, message) return nil } @@ -87,7 +88,7 @@ func (s *Service) appendAssistantMessageOnlyAndSave( } state.session.Messages = append(state.session.Messages, assistant) state.touchSession() - return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ + if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ SessionID: state.session.ID, Messages: []providertypes.Message{assistant}, UpdatedAt: state.session.UpdatedAt, @@ -95,7 +96,11 @@ func (s *Service) appendAssistantMessageOnlyAndSave( Model: state.session.Model, Workdir: state.session.Workdir, HasUnknownUsage: state.session.HasUnknownUsage, - }) + }); err != nil { + return err + } + appendMemoRunMessage(state, assistant) + return nil } // appendSystemMessageAndSave 将系统提醒消息追加到会话并持久化。 @@ -111,7 +116,7 @@ func (s *Service) appendSystemMessageAndSave(ctx context.Context, state *runStat } state.session.Messages = append(state.session.Messages, message) state.touchSession() - return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ + if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ SessionID: state.session.ID, Messages: []providertypes.Message{message}, UpdatedAt: state.session.UpdatedAt, @@ -119,7 +124,11 @@ func (s *Service) appendSystemMessageAndSave(ctx context.Context, state *runStat Model: state.session.Model, Workdir: state.session.Workdir, HasUnknownUsage: state.session.HasUnknownUsage, - }) + }); err != nil { + return err + } + appendMemoRunMessage(state, message) + return nil } // appendToolMessageAndSave 将工具原始结果写回会话,持久化时仅追加一条 tool message。 @@ -143,7 +152,11 @@ func (s *Service) appendToolMessageAndSave( HasUnknownUsage: state.session.HasUnknownUsage, } state.mu.Unlock() - return s.sessionStore.AppendMessages(ctx, input) + if err := s.sessionStore.AppendMessages(ctx, input); err != nil { + return err + } + appendMemoRunMessage(state, toolMessage) + return nil } // normalizeToolMessageForPersistence 负责在写入会话前收敛工具结果,避免成功结果落成完全空语义消息。 diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 2440bddf..277d9c6e 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -4,6 +4,7 @@ import ( "sync" "time" + providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/controlplane" "neo-code/internal/runtime/decider" runtimefacts "neo-code/internal/runtime/facts" @@ -20,6 +21,7 @@ type runState struct { effectiveWorkdir string compactCount int reactiveCompactAttempts int + memoRunMessages []providertypes.Message rememberedThisRun bool planningEnabled bool taskID string From ec23d015dc70cc85d92265d66e67bf17aae34cd0 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sun, 10 May 2026 10:47:56 +0800 Subject: [PATCH 4/5] =?UTF-8?q?fix:web=E7=AB=AFslash=E6=8C=87=E4=BB=A4?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/chat/ChatInput.test.tsx | 69 ++++++++++++++++++++++ web/src/components/chat/ChatInput.tsx | 25 ++++---- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index 8c610e7b..bf39b1b2 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -25,6 +25,12 @@ vi.mock('./ModelSelector', () => ({ default: () =>
, })) +async function submitSlashCommand(command: string) { + const textarea = screen.getByRole('textbox') as HTMLTextAreaElement + fireEvent.change(textarea, { target: { value: command } }) + fireEvent.keyDown(textarea, { key: 'Enter' }) +} + describe('ChatInput', () => { beforeEach(() => { vi.clearAllMocks() @@ -134,4 +140,67 @@ describe('ChatInput', () => { expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) + it('executes /memo without session id and shows payload.Content', async () => { + mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({ + payload: { + Content: 'User Memo:\n- [user] coding preference', + }, + }) + render() + + await submitSlashCommand('/memo') + + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).toHaveBeenCalledWith('', '', 'memo_list', {}) + }) + await waitFor(() => { + expect(useChatStore.getState().messages.some((msg) => msg.type === 'system' && msg.content.includes('coding preference'))).toBe(true) + }) + }) + + it('uses fallback text when memo payload has no content field', async () => { + mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({ payload: {} }) + render() + + await submitSlashCommand('/memo') + + await waitFor(() => { + expect(useChatStore.getState().messages.some((msg) => msg.type === 'system' && msg.content === 'Memo query complete')).toBe(true) + }) + }) + + it('executes /remember and /forget without session id', async () => { + mockGatewayAPI.executeSystemTool + .mockResolvedValueOnce({ payload: { Content: 'Memory saved: [user] keep tests strict' } }) + .mockResolvedValueOnce({ payload: { Content: 'Removed 1 memo(s) matching \"strict\".' } }) + render() + + await submitSlashCommand('/remember keep tests strict') + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).toHaveBeenNthCalledWith(1, '', '', 'memo_remember', { + type: 'user', + title: 'keep tests strict', + content: 'keep tests strict', + }) + }) + + await submitSlashCommand('/forget strict') + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).toHaveBeenNthCalledWith(2, '', '', 'memo_remove', { + keyword: 'strict', + scope: 'all', + }) + }) + }) + + it('keeps argument validation for /remember and /forget', async () => { + render() + + await submitSlashCommand('/remember') + await submitSlashCommand('/forget') + + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).not.toHaveBeenCalled() + }) + }) }) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index ff4258e4..d5a4bb2e 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -57,6 +57,13 @@ function buildSlashHelpText(commands: AnySlashCommand[]): string { return ['可用命令:', ...lines].join('\n') } +/** 统一提取系统工具返回文本,兼容 payload.content 与 payload.Content。 */ +function extractSystemToolContent(result: unknown, fallback: string): string { + const payload = (result as { payload?: { content?: string; Content?: string } } | null)?.payload + const content = payload?.content ?? payload?.Content + return content || fallback +} + export default function ChatInput() { const gatewayAPI = useGatewayAPI() const text = useComposerStore((state) => state.composerText) @@ -149,13 +156,9 @@ export default function ChatInput() { return true } case '/memo': { - if (!isValidSessionId(currentSessionId)) { - useUIStore.getState().showToast('Send a message first to start a session', 'error') - return true - } try { const result = await api.executeSystemTool(currentSessionId, '', 'memo_list', {}) - addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo query complete') + addSystemMessage(extractSystemToolContent(result, 'Memo query complete')) } catch (err) { console.error('Memo list failed:', err) useUIStore.getState().showToast('Failed to query memo', 'error') @@ -167,17 +170,13 @@ export default function ChatInput() { useUIStore.getState().showToast('Usage: /remember ', 'error') return true } - if (!isValidSessionId(currentSessionId)) { - useUIStore.getState().showToast('Send a message first to start a session', 'error') - return true - } try { const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', { type: 'user', title: argument, content: argument, }) - addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo saved') + addSystemMessage(extractSystemToolContent(result, 'Memo saved')) } catch (err) { console.error('Remember failed:', err) useUIStore.getState().showToast('Failed to save memo', 'error') @@ -189,16 +188,12 @@ export default function ChatInput() { useUIStore.getState().showToast('Usage: /forget ', 'error') return true } - if (!isValidSessionId(currentSessionId)) { - useUIStore.getState().showToast('Send a message first to start a session', 'error') - return true - } try { const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', { keyword: argument, scope: 'all', }) - addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo deleted') + addSystemMessage(extractSystemToolContent(result, 'Memo deleted')) } catch (err) { console.error('Forget failed:', err) useUIStore.getState().showToast('Failed to delete memo', 'error') From 0a4e22765b138878e1d2c146e6cc948db5c64fed Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sun, 10 May 2026 16:16:15 +0800 Subject: [PATCH 5/5] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8F=90=E5=8F=96?= =?UTF-8?q?=E7=9A=84=E6=96=87=E6=9C=AC=E5=86=85=E5=AE=B9=EF=BC=8CAutoExtra?= =?UTF-8?q?ctor=20=E5=AF=B9=E6=AF=8F=E6=9D=A1=E6=96=B0=E5=80=99=E9=80=89?= =?UTF-8?q?=E5=8D=95=E7=8B=AC=E5=81=9A=E5=8E=BB=E9=87=8D=E5=86=B3=E7=AD=96?= =?UTF-8?q?=EF=BC=8C=E7=A7=BB=E9=99=A4extract=5Frecent=5Fmessages?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/bootstrap.go | 2 +- internal/app/bootstrap_test.go | 1 - internal/config/config_test.go | 46 +-- internal/config/loader.go | 63 +++- internal/config/loader_test.go | 34 +- internal/config/memo.go | 35 +- internal/context/projection.go | 2 +- internal/context/projection_test.go | 6 + internal/memo/auto_extractor.go | 113 ++----- internal/memo/auto_extractor_test.go | 41 ++- internal/memo/llm_extractor.go | 220 ++++++++---- internal/memo/llm_extractor_test.go | 484 +++++++++------------------ internal/memo/semantic_candidates.go | 321 ++++++++++++++++++ internal/memo/service.go | 40 ++- internal/memo/service_test.go | 91 ++++- internal/memo/types.go | 23 +- internal/tools/memo/remember_test.go | 7 +- 17 files changed, 948 insertions(+), 581 deletions(-) create mode 100644 internal/memo/semantic_candidates.go diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index f7c6ce5f..41bea565 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -115,7 +115,7 @@ func newMemoExtractorAdapter( }) }) - scheduler.ScheduleWithExtractor(sessionID, messages, memo.NewLLMExtractor(generator, cfg.Memo.ExtractRecentMessages)) + scheduler.ScheduleWithExtractor(sessionID, messages, memo.NewLLMExtractor(generator)) }) } diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index 917823cb..a1af4c49 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -1679,7 +1679,6 @@ func TestNewMemoExtractorAdapterUsesFullRunMemoWindow(t *testing.T) { t.Setenv(config.OpenAIDefaultAPIKeyEnv, "token") cfg := config.StaticDefaults().Clone() cfg.SelectedProvider = config.OpenAIName - cfg.Memo.ExtractRecentMessages = 3 manager := config.NewManager(config.NewLoader("", &cfg)) providerStub := &stubMemoProvider{ diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4e5b8c39..149307f9 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1495,12 +1495,11 @@ func TestMemoConfigClone(t *testing.T) { t.Parallel() original := MemoConfig{ - Enabled: true, - AutoExtract: false, - MaxEntries: 100, - MaxIndexBytes: 2048, - ExtractTimeoutSec: 9, - ExtractRecentMessages: 3, + Enabled: true, + AutoExtract: false, + MaxEntries: 100, + MaxIndexBytes: 2048, + ExtractTimeoutSec: 9, } cloned := original.Clone() if cloned != original { @@ -1518,10 +1517,9 @@ func TestMemoConfigApplyDefaults(t *testing.T) { t.Run("fills zero fields", func(t *testing.T) { cfg := MemoConfig{} cfg.ApplyDefaults(MemoConfig{ - MaxEntries: DefaultMemoMaxEntries, - MaxIndexBytes: DefaultMemoMaxIndexBytes, - ExtractTimeoutSec: DefaultMemoExtractTimeoutSec, - ExtractRecentMessages: DefaultMemoExtractRecentMessage, + MaxEntries: DefaultMemoMaxEntries, + MaxIndexBytes: DefaultMemoMaxIndexBytes, + ExtractTimeoutSec: DefaultMemoExtractTimeoutSec, }) if cfg.MaxEntries != DefaultMemoMaxEntries { t.Errorf("MaxEntries = %d, want %d", cfg.MaxEntries, DefaultMemoMaxEntries) @@ -1532,33 +1530,28 @@ func TestMemoConfigApplyDefaults(t *testing.T) { if cfg.ExtractTimeoutSec != DefaultMemoExtractTimeoutSec { t.Errorf("ExtractTimeoutSec = %d, want %d", cfg.ExtractTimeoutSec, DefaultMemoExtractTimeoutSec) } - if cfg.ExtractRecentMessages != DefaultMemoExtractRecentMessage { - t.Errorf("ExtractRecentMessages = %d, want %d", cfg.ExtractRecentMessages, DefaultMemoExtractRecentMessage) - } }) t.Run("preserves explicit fields", func(t *testing.T) { cfg := MemoConfig{ - MaxEntries: 50, - MaxIndexBytes: 1024, - ExtractTimeoutSec: 30, - ExtractRecentMessages: 5, + MaxEntries: 50, + MaxIndexBytes: 1024, + ExtractTimeoutSec: 30, } cfg.ApplyDefaults(defaultMemoConfig()) - if cfg.MaxEntries != 50 || cfg.MaxIndexBytes != 1024 || cfg.ExtractTimeoutSec != 30 || cfg.ExtractRecentMessages != 5 { + if cfg.MaxEntries != 50 || cfg.MaxIndexBytes != 1024 || cfg.ExtractTimeoutSec != 30 { t.Fatalf("ApplyDefaults() unexpectedly overwrote explicit values: %+v", cfg) } }) t.Run("preserves negative fields for validation", func(t *testing.T) { cfg := MemoConfig{ - MaxEntries: -1, - MaxIndexBytes: -2, - ExtractTimeoutSec: -3, - ExtractRecentMessages: -4, + MaxEntries: -1, + MaxIndexBytes: -2, + ExtractTimeoutSec: -3, } cfg.ApplyDefaults(defaultMemoConfig()) - if cfg.MaxEntries != -1 || cfg.MaxIndexBytes != -2 || cfg.ExtractTimeoutSec != -3 || cfg.ExtractRecentMessages != -4 { + if cfg.MaxEntries != -1 || cfg.MaxIndexBytes != -2 || cfg.ExtractTimeoutSec != -3 { t.Fatalf("ApplyDefaults() unexpectedly rewrote invalid values: %+v", cfg) } }) @@ -1603,13 +1596,6 @@ func TestMemoConfigValidate(t *testing.T) { } }) - t.Run("non-positive ExtractRecentMessages", func(t *testing.T) { - cfg := defaultMemoConfig() - cfg.ExtractRecentMessages = 0 - if err := cfg.Validate(); err == nil { - t.Fatal("non-positive ExtractRecentMessages should fail validation") - } - }) } func TestNormalizeWorkdirEdgeCases(t *testing.T) { diff --git a/internal/config/loader.go b/internal/config/loader.go index 601d214c..a07a689c 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -66,12 +66,11 @@ type persistedAskConfig struct { } type persistedMemoConfig struct { - Enabled *bool `yaml:"enabled,omitempty"` - AutoExtract *bool `yaml:"auto_extract,omitempty"` - MaxEntries *int `yaml:"max_entries,omitempty"` - MaxIndexBytes *int `yaml:"max_index_bytes,omitempty"` - ExtractTimeoutSec *int `yaml:"extract_timeout_sec,omitempty"` - ExtractRecentMessages *int `yaml:"extract_recent_messages,omitempty"` + Enabled *bool `yaml:"enabled,omitempty"` + AutoExtract *bool `yaml:"auto_extract,omitempty"` + MaxEntries *int `yaml:"max_entries,omitempty"` + MaxIndexBytes *int `yaml:"max_index_bytes,omitempty"` + ExtractTimeoutSec *int `yaml:"extract_timeout_sec,omitempty"` } func NewLoader(baseDir string, defaults *Config) *Loader { @@ -225,6 +224,9 @@ func parseConfigWithContextDefaults( } func parseCurrentConfig(data []byte, contextDefaults ContextConfig, memoDefaults MemoConfig) (*Config, error) { + if err := rejectRemovedMemoFields(data); err != nil { + return nil, err + } var file persistedConfig decoder := yaml.NewDecoder(bytes.NewReader(data)) decoder.KnownFields(true) @@ -384,14 +386,12 @@ func newPersistedMemoConfig(cfg MemoConfig) persistedMemoConfig { maxEntries := cfg.MaxEntries maxIndexBytes := cfg.MaxIndexBytes extractTimeoutSec := cfg.ExtractTimeoutSec - extractRecentMessages := cfg.ExtractRecentMessages return persistedMemoConfig{ - Enabled: &enabled, - AutoExtract: &autoExtract, - MaxEntries: &maxEntries, - MaxIndexBytes: &maxIndexBytes, - ExtractTimeoutSec: &extractTimeoutSec, - ExtractRecentMessages: &extractRecentMessages, + Enabled: &enabled, + AutoExtract: &autoExtract, + MaxEntries: &maxEntries, + MaxIndexBytes: &maxIndexBytes, + ExtractTimeoutSec: &extractTimeoutSec, } } @@ -413,12 +413,43 @@ func fromPersistedMemoConfig(file persistedMemoConfig, defaults MemoConfig) Memo if file.ExtractTimeoutSec != nil { out.ExtractTimeoutSec = *file.ExtractTimeoutSec } - if file.ExtractRecentMessages != nil { - out.ExtractRecentMessages = *file.ExtractRecentMessages - } return out } +// rejectRemovedMemoFields 在 strict decode 前拦截已删除的 memo 字段,输出明确迁移提示。 +func rejectRemovedMemoFields(data []byte) error { + var root yaml.Node + if err := yaml.Unmarshal(data, &root); err != nil { + return err + } + if len(root.Content) == 0 { + return nil + } + doc := root.Content[0] + if doc.Kind != yaml.MappingNode { + return nil + } + + for i := 0; i < len(doc.Content); i += 2 { + if strings.TrimSpace(doc.Content[i].Value) != "memo" { + continue + } + memoNode := doc.Content[i+1] + if memoNode.Kind != yaml.MappingNode { + return nil + } + for j := 0; j < len(memoNode.Content); j += 2 { + if strings.TrimSpace(memoNode.Content[j].Value) == "extract_recent_messages" { + return fmt.Errorf( + "config: memo.extract_recent_messages has been removed; memory extraction now always uses the full run boundary", + ) + } + } + return nil + } + return nil +} + // normalizeVerificationSchemaContent 在内存中预处理 verification schema,避免旧字段先于 strict decode 触发硬失败。 func normalizeVerificationSchemaContent(raw []byte) ([]byte, bool, error) { if len(bytes.TrimSpace(raw)) == 0 { diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 77cec06f..65d14c6f 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -1930,7 +1930,6 @@ memo: max_entries: 123 max_index_bytes: 4096 extract_timeout_sec: 9 - extract_recent_messages: 4 ` writeLoaderConfig(t, loader, raw) @@ -1953,9 +1952,6 @@ memo: if cfg.Memo.ExtractTimeoutSec != 9 { t.Fatalf("expected memo.extract_timeout_sec=9, got %d", cfg.Memo.ExtractTimeoutSec) } - if cfg.Memo.ExtractRecentMessages != 4 { - t.Fatalf("expected memo.extract_recent_messages=4, got %d", cfg.Memo.ExtractRecentMessages) - } data, err := os.ReadFile(loader.ConfigPath()) if err != nil { @@ -2001,9 +1997,6 @@ shell: powershell if cfg.Memo.ExtractTimeoutSec <= 0 { t.Fatalf("expected memo.extract_timeout_sec to be defaulted, got %d", cfg.Memo.ExtractTimeoutSec) } - if cfg.Memo.ExtractRecentMessages <= 0 { - t.Fatalf("expected memo.extract_recent_messages to be defaulted, got %d", cfg.Memo.ExtractRecentMessages) - } } func TestLoaderRejectsLegacyMemoMaxIndexLinesField(t *testing.T) { @@ -2028,6 +2021,28 @@ memo: } } +func TestLoaderRejectsRemovedMemoExtractRecentMessagesField(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-4.1 +shell: powershell +memo: + extract_recent_messages: 4 +` + writeLoaderConfig(t, loader, raw) + + cfg, err := loader.Load(context.Background()) + if err == nil { + t.Fatalf("expected removed memo field to be rejected, cfg=%+v", cfg) + } + if !strings.Contains(err.Error(), "memo.extract_recent_messages has been removed") { + t.Fatalf("expected migration hint for extract_recent_messages, got %v", err) + } +} + func TestLoaderRejectsExplicitInvalidMemoNumbers(t *testing.T) { t.Parallel() @@ -2051,11 +2066,6 @@ func TestLoaderRejectsExplicitInvalidMemoNumbers(t *testing.T) { fieldYAML: "extract_timeout_sec: -1", errContain: "config: memo: extract_timeout_sec must be greater than 0", }, - { - name: "negative extract_recent_messages", - fieldYAML: "extract_recent_messages: -1", - errContain: "config: memo: extract_recent_messages must be greater than 0", - }, } for _, tt := range tests { diff --git a/internal/config/memo.go b/internal/config/memo.go index 9da2e2d3..010c9acf 100644 --- a/internal/config/memo.go +++ b/internal/config/memo.go @@ -3,31 +3,28 @@ package config import "errors" const ( - DefaultMemoMaxEntries = 200 - DefaultMemoMaxIndexBytes = 16 * 1024 - DefaultMemoExtractTimeoutSec = 15 - DefaultMemoExtractRecentMessage = 10 + DefaultMemoMaxEntries = 200 + DefaultMemoMaxIndexBytes = 16 * 1024 + DefaultMemoExtractTimeoutSec = 15 ) // MemoConfig 控制跨会话持久记忆的行为配置。 type MemoConfig struct { - Enabled bool `yaml:"enabled,omitempty"` - AutoExtract bool `yaml:"auto_extract,omitempty"` - MaxEntries int `yaml:"max_entries,omitempty"` - MaxIndexBytes int `yaml:"max_index_bytes,omitempty"` - ExtractTimeoutSec int `yaml:"extract_timeout_sec,omitempty"` - ExtractRecentMessages int `yaml:"extract_recent_messages,omitempty"` + Enabled bool `yaml:"enabled,omitempty"` + AutoExtract bool `yaml:"auto_extract,omitempty"` + MaxEntries int `yaml:"max_entries,omitempty"` + MaxIndexBytes int `yaml:"max_index_bytes,omitempty"` + ExtractTimeoutSec int `yaml:"extract_timeout_sec,omitempty"` } // defaultMemoConfig 返回跨会话记忆的默认配置。 func defaultMemoConfig() MemoConfig { return MemoConfig{ - Enabled: true, - AutoExtract: true, - MaxEntries: DefaultMemoMaxEntries, - MaxIndexBytes: DefaultMemoMaxIndexBytes, - ExtractTimeoutSec: DefaultMemoExtractTimeoutSec, - ExtractRecentMessages: DefaultMemoExtractRecentMessage, + Enabled: true, + AutoExtract: true, + MaxEntries: DefaultMemoMaxEntries, + MaxIndexBytes: DefaultMemoMaxIndexBytes, + ExtractTimeoutSec: DefaultMemoExtractTimeoutSec, } } @@ -50,9 +47,6 @@ func (c *MemoConfig) ApplyDefaults(defaults MemoConfig) { if c.ExtractTimeoutSec == 0 { c.ExtractTimeoutSec = defaults.ExtractTimeoutSec } - if c.ExtractRecentMessages == 0 { - c.ExtractRecentMessages = defaults.ExtractRecentMessages - } } // Validate 校验 memo 配置是否合法。 @@ -66,8 +60,5 @@ func (c MemoConfig) Validate() error { if c.ExtractTimeoutSec <= 0 { return errors.New("extract_timeout_sec must be greater than 0") } - if c.ExtractRecentMessages <= 0 { - return errors.New("extract_recent_messages must be greater than 0") - } return nil } diff --git a/internal/context/projection.go b/internal/context/projection.go index 298f7e8c..265beb0c 100644 --- a/internal/context/projection.go +++ b/internal/context/projection.go @@ -88,7 +88,7 @@ func BuildMemoExtractionMessagesForModel(messages []providertypes.Message) []pro keep := make([]bool, len(messages)) for index := 0; index < len(messages); index++ { message := messages[index] - if message.Role == providertypes.RoleTool { + if message.Role == providertypes.RoleTool || message.Role == providertypes.RoleSystem { continue } diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go index 3b0c932a..80edb854 100644 --- a/internal/context/projection_test.go +++ b/internal/context/projection_test.go @@ -145,6 +145,7 @@ func TestBuildMemoExtractionMessagesForModelKeepsFullRunSafeSpans(t *testing.T) messages := []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}}, + {Role: providertypes.RoleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart("must call todo_write")}}, {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan")}}, { Role: providertypes.RoleAssistant, @@ -174,6 +175,11 @@ func TestBuildMemoExtractionMessagesForModelKeepsFullRunSafeSpans(t *testing.T) if renderDisplayParts(projected[0].Parts) != "first" || renderDisplayParts(projected[3].Parts) != "last" { t.Fatalf("expected full run user messages to remain, got %+v", projected) } + for _, message := range projected { + if message.Role == providertypes.RoleSystem { + t.Fatalf("system reminder should be excluded from memo extraction window: %+v", projected) + } + } if projected[1].Role != providertypes.RoleAssistant || len(projected[1].ToolCalls) != 1 { t.Fatalf("expected complete assistant tool span, got %+v", projected[1]) } diff --git a/internal/memo/auto_extractor.go b/internal/memo/auto_extractor.go index a1ccb43b..812d64e8 100644 --- a/internal/memo/auto_extractor.go +++ b/internal/memo/auto_extractor.go @@ -149,7 +149,6 @@ func (a *AutoExtractor) handleDebounce(sessionID string, state *autoExtractState req := *state.pending state.pending = nil - // 增量检测:消息未变化时跳过提取 fp := computeMessageFingerprint(req.messages) if fp == state.lastFingerprint { a.armIdleTimerLocked(sessionID, state) @@ -183,7 +182,7 @@ func (a *AutoExtractor) handleRunDone(sessionID string, state *autoExtractState) a.armIdleTimerLocked(sessionID, state) } -// armIdleTimerLocked 在会话空闲时安排状态回收,避免 map 与 goroutine 长期累积。 +// armIdleTimerLocked 在会话空闲时安排状态回收,避免 map 中 goroutine 长期累积。 func (a *AutoExtractor) armIdleTimerLocked(sessionID string, state *autoExtractState) { stopTimer(state.idleTimer) state.idleSeq++ @@ -203,7 +202,6 @@ func (a *AutoExtractor) handleIdle(sessionID string, state *autoExtractState, se a.mu.Lock() defer a.mu.Unlock() - // 在删除前再次确认 map 中仍指向当前状态,防止旧回调回收新状态。 if a.states[sessionID] != state { return } @@ -219,30 +217,11 @@ func isIdleStateLocked(state *autoExtractState, seq uint64) bool { return state.idleSeq == seq && !state.running && state.pending == nil } -// extractAndStore 执行提取,并在写入前做本地批次去重和持久化级别的原子去重。 -// 返回值表示本次提取和写入流程是否成功完成,可用于更新增量指纹。 +// extractAndStore 执行提取,并在写入前完成本地去重和语义决策。 func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []providertypes.Message) bool { ctx, cancel := context.WithTimeout(context.Background(), a.extractTimeout) defer cancel() - if decisionExtractor, ok := extractor.(DecisionExtractor); ok { - existing, err := a.svc.autoExtractionCandidates(ctx) - if err != nil { - a.logError("memo: auto extract load candidates failed: %v", err) - return false - } - decisions, err := decisionExtractor.ExtractDecisions(ctx, messages, existing) - if err != nil { - if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) { - a.logError("memo: auto extract skipped (protocol_mismatch): %v", err) - return true - } - a.logError("memo: auto extract failed: %v", err) - return false - } - return a.applyExtractionDecisions(ctx, decisions) - } - entries, err := extractor.Extract(ctx, messages) if err != nil { if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) { @@ -256,6 +235,7 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider return true } + resolver, _ := extractor.(DecisionResolver) seen := make(map[string]struct{}, len(entries)) succeeded := true for _, entry := range entries { @@ -267,60 +247,55 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider if _, exists := seen[key]; exists { continue } + seen[key] = struct{}{} + + if resolver == nil { + if _, err := a.svc.addAutoExtractIfAbsent(ctx, entry); err != nil { + a.logError("memo: auto extract add failed: %v", err) + succeeded = false + } + continue + } - added, err := a.svc.addAutoExtractIfAbsent(ctx, entry) + shortlist, err := a.svc.semanticCandidateShortlist(ctx, entry, semanticCandidateShortlistLimit) if err != nil { - a.logError("memo: auto extract add failed: %v", err) + a.logError("memo: auto extract shortlist failed: %v", err) succeeded = false continue } - - seen[key] = struct{}{} - if !added { + if len(shortlist) == 0 { + if _, err := a.svc.addAutoExtractIfAbsent(ctx, entry); err != nil { + a.logError("memo: auto extract add failed: %v", err) + succeeded = false + } continue } - } - return succeeded -} -// applyExtractionDecisions 应用模型返回的 create/update/skip 决策,并保留本地精确去重兜底。 -func (a *AutoExtractor) applyExtractionDecisions(ctx context.Context, decisions []ExtractionDecision) bool { - if len(decisions) == 0 { - return true - } + decision, err := resolver.ResolveDecision(ctx, entry, shortlist) + if err != nil { + if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) { + a.logError("memo: auto extract skipped (protocol_mismatch): %v", err) + return true + } + a.logError("memo: auto extract resolve failed: %v", err) + succeeded = false + continue + } - seenCreates := make(map[string]struct{}, len(decisions)) - succeeded := true - for _, decision := range decisions { switch decision.Action { case ExtractionActionCreate: - entry := decision.Entry - entry.Source = SourceAutoExtract - key := autoExtractDedupKey(entry) - if key == "" { - continue - } - if _, exists := seenCreates[key]; exists { - continue - } - added, err := a.svc.addAutoExtractIfAbsent(ctx, entry) - if err != nil { + createEntry := decision.Entry + createEntry.Source = SourceAutoExtract + if _, err := a.svc.addAutoExtractIfAbsent(ctx, createEntry); err != nil { a.logError("memo: auto extract add failed: %v", err) succeeded = false - continue - } - seenCreates[key] = struct{}{} - if !added { - continue } case ExtractionActionUpdate: - entry := decision.Entry - entry.Source = SourceAutoExtract - _, err := a.svc.updateAutoExtractIfAllowed(ctx, decision.Ref, entry) - if err != nil { + updateEntry := decision.Entry + updateEntry.Source = SourceAutoExtract + if _, err := a.svc.updateAutoExtractIfAllowed(ctx, decision.Ref, updateEntry); err != nil { a.logError("memo: auto extract update failed: %v", err) succeeded = false - continue } case ExtractionActionSkip: continue @@ -341,21 +316,8 @@ func autoExtractDedupKey(entry Entry) string { // parseTopicSourceAndContent 从 topic 文件中解析 source frontmatter 和正文内容。 func parseTopicSourceAndContent(topic string) (string, string) { - parts := strings.Split(topic, "---") - if len(parts) < 3 { - return "", strings.TrimSpace(topic) - } - - frontmatter := parts[1] - body := strings.TrimSpace(strings.Join(parts[2:], "---")) - for _, line := range strings.Split(frontmatter, "\n") { - line = strings.TrimSpace(line) - if !strings.HasPrefix(line, "source:") { - continue - } - return strings.TrimSpace(strings.TrimPrefix(line, "source:")), body - } - return "", body + source, _, body := parseTopicSnapshot(topic) + return source, body } // cloneProviderMessages 深拷贝消息切片,避免后台任务读取到后续修改。 @@ -400,7 +362,6 @@ func computeMessageFingerprint(messages []providertypes.Message) uint64 { return h.Sum64() } - // 回退路径仅在意外序列化失败时使用,至少保留文本消息的增量检测能力。 for _, msg := range messages { _, _ = h.Write([]byte(msg.Role)) _, _ = h.Write([]byte{0}) diff --git a/internal/memo/auto_extractor_test.go b/internal/memo/auto_extractor_test.go index 6c4e5835..06e012c9 100644 --- a/internal/memo/auto_extractor_test.go +++ b/internal/memo/auto_extractor_test.go @@ -39,30 +39,38 @@ func (s *stubMemoExtractor) Calls() int { } type stubDecisionMemoExtractor struct { - mu sync.Mutex - callCount int - candidates []ExtractionCandidate - decisions []ExtractionDecision - err error + mu sync.Mutex + callCount int + candidates []ExtractionCandidate + extractEntries []Entry + decisions []ExtractionDecision + err error } func (s *stubDecisionMemoExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { - return nil, nil + s.mu.Lock() + defer s.mu.Unlock() + return append([]Entry(nil), s.extractEntries...), nil } -func (s *stubDecisionMemoExtractor) ExtractDecisions( +func (s *stubDecisionMemoExtractor) ResolveDecision( ctx context.Context, - messages []providertypes.Message, + candidate Entry, existing []ExtractionCandidate, -) ([]ExtractionDecision, error) { +) (ExtractionDecision, error) { s.mu.Lock() defer s.mu.Unlock() s.callCount++ s.candidates = append([]ExtractionCandidate(nil), existing...) if s.err != nil { - return nil, s.err + return ExtractionDecision{}, s.err } - return append([]ExtractionDecision(nil), s.decisions...), nil + if len(s.decisions) == 0 { + return ExtractionDecision{Action: ExtractionActionCreate, Entry: candidate}, nil + } + decision := s.decisions[0] + s.decisions = append([]ExtractionDecision(nil), s.decisions[1:]...) + return decision, nil } func newAutoExtractorTestService(t *testing.T) *Service { @@ -278,6 +286,9 @@ func TestAutoExtractorAppliesSemanticUpdateOnlyForAutoExtractedMemory(t *testing } extractor := &stubDecisionMemoExtractor{ + extractEntries: []Entry{ + {Type: TypeFeedback, Title: "测试策略", Content: "用户要求修改后先跑相关测试。"}, + }, decisions: []ExtractionDecision{ { Action: ExtractionActionUpdate, @@ -324,8 +335,8 @@ func TestAutoExtractorAppliesSemanticUpdateOnlyForAutoExtractedMemory(t *testing if len(manualRecall) != 1 || strings.Contains(manualRecall[0].Content, "不应覆盖") { t.Fatalf("manual memory should not be overwritten, got %+v", manualRecall) } - if len(extractor.candidates) != 2 { - t.Fatalf("expected existing candidates to be provided, got %+v", extractor.candidates) + if len(extractor.candidates) != 1 || extractor.candidates[0].Ref != autoRef { + t.Fatalf("expected shortlist to target the auto-extracted memory, got %+v", extractor.candidates) } } @@ -341,6 +352,10 @@ func TestAutoExtractorSemanticCreateStillUsesExactDedup(t *testing.T) { } extractor := &stubDecisionMemoExtractor{ + extractEntries: []Entry{ + {Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"}, + {Type: TypeProject, Title: "新事实", Content: "项目需要语义去重。"}, + }, decisions: []ExtractionDecision{ {Action: ExtractionActionCreate, Entry: Entry{Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"}}, {Action: ExtractionActionCreate, Entry: Entry{Type: TypeProject, Title: "新事实", Content: "项目需要语义去重。"}}, diff --git a/internal/memo/llm_extractor.go b/internal/memo/llm_extractor.go index 621446de..ba239fe5 100644 --- a/internal/memo/llm_extractor.go +++ b/internal/memo/llm_extractor.go @@ -13,17 +13,16 @@ import ( ) var ( - // ErrExtractionNoJSONArray 表示提取结果中找不到合法 JSON 数组起始。 - ErrExtractionNoJSONArray = errors.New("memo: extraction response does not contain a JSON array") - // ErrExtractionIncompleteJSONArray 表示提取结果中的 JSON 数组不完整。 - ErrExtractionIncompleteJSONArray = errors.New("memo: extraction response contains an incomplete JSON array") + // ErrExtractionNoJSONArray 表示提取结果中找不到合法 JSON 数组或对象起始。 + ErrExtractionNoJSONArray = errors.New("memo: extraction response does not contain a JSON payload") + // ErrExtractionIncompleteJSONArray 表示提取结果中的 JSON 数组或对象不完整。 + ErrExtractionIncompleteJSONArray = errors.New("memo: extraction response contains an incomplete JSON payload") ) -// LLMExtractor 基于 LLM 分析最近对话,并返回结构化记忆条目。 +// LLMExtractor 基于 LLM 分析当前 run 对话,并输出结构化记忆或去重决策。 type LLMExtractor struct { - generator TextGenerator - now func() time.Time - recentMessageLimit int + generator TextGenerator + now func() time.Time } type extractedEntry struct { @@ -36,38 +35,15 @@ type extractedEntry struct { } // NewLLMExtractor 创建基于 TextGenerator 的记忆提取器。 -func NewLLMExtractor(generator TextGenerator, recentMessageLimit int) *LLMExtractor { - if recentMessageLimit <= 0 { - recentMessageLimit = 10 - } +func NewLLMExtractor(generator TextGenerator) *LLMExtractor { return &LLMExtractor{ - generator: generator, - now: time.Now, - recentMessageLimit: recentMessageLimit, + generator: generator, + now: time.Now, } } // Extract 从当前 run 对话中提取可跨会话持久化的新增记忆条目。 func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { - decisions, err := e.ExtractDecisions(ctx, messages, nil) - if err != nil { - return nil, err - } - entries := make([]Entry, 0, len(decisions)) - for _, decision := range decisions { - if decision.Action == ExtractionActionCreate { - entries = append(entries, decision.Entry) - } - } - return entries, nil -} - -// ExtractDecisions 从当前 run 对话中提取记忆,并结合既有记忆输出新增、合并或跳过决策。 -func (e *LLMExtractor) ExtractDecisions( - ctx context.Context, - messages []providertypes.Message, - existing []ExtractionCandidate, -) ([]ExtractionDecision, error) { if err := ctx.Err(); err != nil { return nil, err } @@ -80,7 +56,7 @@ func (e *LLMExtractor) ExtractDecisions( return nil, nil } - response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now(), existing), runMessages) + response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now()), runMessages) if err != nil { return nil, err } @@ -98,33 +74,63 @@ func (e *LLMExtractor) ExtractDecisions( return nil, fmt.Errorf("memo: parse extraction response: %w", err) } - decisions := make([]ExtractionDecision, 0, len(extracted)) + entries := make([]Entry, 0, len(extracted)) for _, item := range extracted { - decision, ok := toExtractionDecision(item) + entry, ok := toMemoEntry(item) if !ok { continue } - decisions = append(decisions, decision) + entries = append(entries, entry) } - return decisions, nil + return entries, nil } -// buildExtractionPrompt 构造记忆提取专用的 system prompt。 -func buildExtractionPrompt(now time.Time, existing []ExtractionCandidate) string { - currentDate := now.In(time.Local).Format("2006-01-02") - existingJSON := "[]" - if len(existing) > 0 { - if data, err := json.Marshal(existing); err == nil { - existingJSON = string(data) - } +// ResolveDecision 结合单条候选记忆与 shortlist,解析 create/update/skip 决策。 +func (e *LLMExtractor) ResolveDecision( + ctx context.Context, + candidate Entry, + existing []ExtractionCandidate, +) (ExtractionDecision, error) { + if err := ctx.Err(); err != nil { + return ExtractionDecision{}, err + } + if e == nil || e.generator == nil { + return ExtractionDecision{}, errors.New("memo: text generator is nil") + } + + response, err := e.generator.Generate(ctx, buildResolutionPrompt(e.now(), candidate, existing), buildResolutionMessages(candidate)) + if err != nil { + return ExtractionDecision{}, err + } + if err := ctx.Err(); err != nil { + return ExtractionDecision{}, err } + + jsonText, err := extractJSONObject(response) + if err != nil { + return ExtractionDecision{}, err + } + + var extracted extractedEntry + if err := json.Unmarshal([]byte(jsonText), &extracted); err != nil { + return ExtractionDecision{}, fmt.Errorf("memo: parse resolution response: %w", err) + } + decision, ok := toExtractionDecision(extracted) + if !ok { + return ExtractionDecision{}, errors.New("memo: resolution response is invalid") + } + return decision, nil +} + +// buildExtractionPrompt 构造仅负责“当前 run 提取”的 system prompt。 +func buildExtractionPrompt(now time.Time) string { + currentDate := now.In(time.Local).Format("2006-01-02") return strings.TrimSpace(fmt.Sprintf(` 你是一个记忆提取助手(memory extraction assistant)。 -分析当前 run 对话中值得跨会话持久记住的信息,并结合既有记忆完成语义去重,返回严格 JSON 数组。 +分析当前 run 对话中值得跨会话持久记住的信息,并返回严格 JSON 数组。 当前本地日期:%s 如果对话中出现“今天、明天、下周二”等相对日期,必须先转换为绝对日期再写入 content。 - 只允许以下四种 type: - user: 用户偏好、习惯、背景、专长 - feedback: 用户对 Agent 做法的纠正、要求、确认过的工作方式 @@ -136,18 +142,68 @@ func buildExtractionPrompt(now time.Time, existing []ExtractionCandidate) string 2. 不要提取通用编程知识、代码结构、文件路径、Git 历史。 3. 每条记忆必须具体、可操作。 4. 没有值得记住的信息时,返回 []。 -5. 如果新信息与既有记忆语义相同或只是轻微改写,输出 action="skip"。 -6. 如果新信息能补充或修正既有 source="extractor_auto" 的记忆,输出 action="update" 并填写目标 ref。 -7. 不允许 update source 不是 "extractor_auto" 的既有记忆;这类相近内容只能 skip。 -8. 如果是全新的可持久化信息,输出 action="create"。 -9. 输出必须是 JSON 数组,不要输出任何额外解释。 +5. 输出必须是 JSON 数组,不要输出任何额外解释。 -既有记忆候选(JSON): +输出格式: +[{"type":"user","title":"...","content":"...","keywords":["..."]}] +`, currentDate)) +} + +// buildResolutionPrompt 构造单条候选记忆的去重决策提示。 +func buildResolutionPrompt(now time.Time, candidate Entry, existing []ExtractionCandidate) string { + currentDate := now.In(time.Local).Format("2006-01-02") + candidateJSON := marshalPromptJSON(struct { + Type Type `json:"type"` + Title string `json:"title"` + Content string `json:"content"` + Keywords []string `json:"keywords,omitempty"` + }{ + Type: candidate.Type, + Title: candidate.Title, + Content: candidate.Content, + Keywords: candidate.Keywords, + }) + existingJSON := marshalPromptJSON(existing) + return strings.TrimSpace(fmt.Sprintf(` +你是一个记忆去重决策助手(memory dedupe assistant)。 +分析一条新的候选记忆与已有记忆 shortlist,返回严格 JSON 对象。 + +当前本地日期:%s +如果候选内容中的日期是相对日期,先换算成绝对日期再判断。 + +规则: +1. 若候选记忆与 shortlist 中某条记忆语义相同,返回 {"action":"skip","ref":"..."}。 +2. 若候选记忆补充或修正了某条 source="extractor_auto" 的已有记忆,返回 {"action":"update","ref":"...","title":"...","content":"...","keywords":[...]}。 +3. 不允许更新 source 不是 "extractor_auto" 的已有记忆;遇到这类情况只能返回 skip。 +4. 若 shortlist 中没有合适对象,返回 {"action":"create","type":"...","title":"...","content":"...","keywords":[...]}。 +5. 输出必须是单个 JSON 对象,不要输出任何额外解释。 + +候选记忆(JSON): %s -输出格式: -[{"action":"create","type":"user","title":"...","content":"...","keywords":["..."]},{"action":"update","ref":"project:p.md","title":"...","content":"...","keywords":["..."]},{"action":"skip","ref":"user:u.md"}] -`, currentDate, existingJSON)) +已有记忆 shortlist(JSON): +%s +`, currentDate, candidateJSON, existingJSON)) +} + +// buildResolutionMessages 为去重决策构造最小 provider 输入,避免空消息列表触发兼容问题。 +func buildResolutionMessages(candidate Entry) []providertypes.Message { + text := fmt.Sprintf("请为这条候选记忆返回 create、update 或 skip 决策:%s", candidate.Title) + return []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart(text)}, + }, + } +} + +// marshalPromptJSON 将结构化数据压缩为 prompt 内联 JSON。 +func marshalPromptJSON(value any) string { + data, err := json.Marshal(value) + if err != nil { + return "null" + } + return string(data) } // containsUserMessage 检查待提取消息中是否包含用户输入。 @@ -189,9 +245,13 @@ func toExtractionDecision(item extractedEntry) (ExtractionDecision, bool) { return ExtractionDecision{}, false } if action == ExtractionActionSkip { + ref := strings.TrimSpace(item.Ref) + if ref == "" { + return ExtractionDecision{}, false + } return ExtractionDecision{ Action: action, - Ref: strings.TrimSpace(item.Ref), + Ref: ref, }, true } @@ -312,3 +372,45 @@ func extractJSONArray(text string) (string, error) { return "", ErrExtractionIncompleteJSONArray } + +// extractJSONObject 从模型返回文本中提取最外层 JSON 对象。 +func extractJSONObject(text string) (string, error) { + start := strings.Index(text, "{") + if start < 0 { + return "", ErrExtractionNoJSONArray + } + + depth := 0 + inString := false + escaped := false + for index := start; index < len(text); index++ { + ch := text[index] + if inString { + if escaped { + escaped = false + continue + } + switch ch { + case '\\': + escaped = true + case '"': + inString = false + } + continue + } + + switch ch { + case '"': + inString = true + case '{': + depth++ + case '}': + depth-- + if depth == 0 { + return strings.TrimSpace(text[start : index+1]), nil + } + } + } + + return "", ErrExtractionIncompleteJSONArray +} diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go index 06813229..987c59fb 100644 --- a/internal/memo/llm_extractor_test.go +++ b/internal/memo/llm_extractor_test.go @@ -32,12 +32,12 @@ func (s *stubTextGenerator) Generate( return s.response, nil } -// TestLLMExtractorExtractValidJSON 验证提取器可以解析合法 JSON 并收敛字段。 +// TestLLMExtractorExtractValidJSON 验证提取器可解析合法 JSON 并收敛字段。 func TestLLMExtractorExtractValidJSON(t *testing.T) { generator := &stubTextGenerator{ - response: `[{"type":"user","title":" 偏好 Go 代码风格 ","content":"用户偏好使用 Go 惯用写法。","keywords":["go"," style ","go"]}]`, + response: `[{"type":"user","title":" 偏好 Go 代码风格 ","content":"用户偏好使用 Go 惯用写法。","keywords":["go"," style ","go"]}]`, } - extractor := NewLLMExtractor(generator, 10) + extractor := NewLLMExtractor(generator) extractor.now = func() time.Time { return time.Date(2026, 4, 13, 10, 0, 0, 0, time.FixedZone("CST", 8*3600)) } @@ -70,219 +70,100 @@ func TestLLMExtractorExtractValidJSON(t *testing.T) { if generator.calls != 1 { t.Fatalf("Generate() calls = %d, want 1", generator.calls) } - if !strings.Contains(generator.prompt, "user: 用户偏好") { + if !strings.Contains(generator.prompt, "用户偏好") { t.Fatalf("prompt should describe user type, got %q", generator.prompt) } - if !strings.Contains(generator.prompt, "当前本地日期:2026-04-13") { + if !strings.Contains(generator.prompt, "2026-04-13") { t.Fatalf("prompt should include absolute local date, got %q", generator.prompt) } - if !strings.Contains(generator.prompt, "必须先转换为绝对日期") { - t.Fatalf("prompt should require absolute dates, got %q", generator.prompt) - } -} - -// TestLLMExtractorExtractEmptyResult 验证空数组响应会返回零条记忆。 -func TestLLMExtractorExtractEmptyResult(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{response: `[]`}, 10) - - entries, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("这轮没有需要记住的内容。")}}, - }) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if len(entries) != 0 { - t.Fatalf("len(entries) = %d, want 0", len(entries)) - } -} - -// TestLLMExtractorExtractNoUserMessage 验证没有用户消息时不会调用模型。 -func TestLLMExtractorExtractNoUserMessage(t *testing.T) { - generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator, 10) - - entries, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("只有助手消息。")}}, - }) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if len(entries) != 0 { - t.Fatalf("len(entries) = %d, want 0", len(entries)) - } - if generator.calls != 0 { - t.Fatalf("Generate() calls = %d, want 0", generator.calls) - } -} - -// TestLLMExtractorExtractNoMessages 验证空消息输入直接返回空结果。 -func TestLLMExtractorExtractNoMessages(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{response: `[]`}, 10) - - entries, err := extractor.Extract(context.Background(), nil) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if len(entries) != 0 { - t.Fatalf("len(entries) = %d, want 0", len(entries)) - } -} - -// TestLLMExtractorExtractInvalidJSON 验证无效 JSON 会返回错误。 -func TestLLMExtractorExtractInvalidJSON(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{response: `[{invalid json}]`}, 10) - - _, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, - }) - if err == nil { - t.Fatal("expected invalid JSON error") - } -} - -// TestLLMExtractorExtractToleratesWrappedJSON 验证 JSON 前后噪声不会影响解析。 -func TestLLMExtractorExtractToleratesWrappedJSON(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{ - response: "分析如下:\n[{\"type\":\"feedback\",\"title\":\"以后先跑测试\",\"content\":\"用户要求修改后先跑测试。\"}]\n以上完毕。", - }, 10) - - entries, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑测试。")}}, - }) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if len(entries) != 1 || entries[0].Type != TypeFeedback { - t.Fatalf("entries = %#v", entries) - } } -// TestLLMExtractorExtractFiltersInvalidEntries 验证非法类型和空字段会被过滤。 -func TestLLMExtractorExtractFiltersInvalidEntries(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{ - response: `[ - {"type":"invalid","title":"bad","content":"bad"}, - {"type":"project","title":" ","content":"missing title"}, - {"type":"reference","title":"文档入口","content":"查看 docs/runtime-provider-event-flow.md"} - ]`, - }, 10) - - entries, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("参考文档在 docs/runtime-provider-event-flow.md。")}}, - }) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if len(entries) != 1 || entries[0].Type != TypeReference { - t.Fatalf("entries = %#v", entries) +// TestLLMExtractorExtractSkipsEmptyOrNonUserInputs 验证缺少有效用户输入时不会调用模型。 +func TestLLMExtractorExtractSkipsEmptyOrNonUserInputs(t *testing.T) { + tests := []struct { + name string + messages []providertypes.Message + }{ + { + name: "nil messages", + messages: nil, + }, + { + name: "assistant only", + messages: []providertypes.Message{ + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("只有助手消息。")}}, + }, + }, + { + name: "image only user", + messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewRemoteImagePart("https://example.com/pic.png")}}, + }, + }, } -} - -// TestLLMExtractorExtractCancelledContext 验证已取消上下文会中止提取。 -func TestLLMExtractorExtractCancelledContext(t *testing.T) { - generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator, 10) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - _, err := extractor.Extract(ctx, []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, - }) - if !errors.Is(err, context.Canceled) { - t.Fatalf("Extract() error = %v, want context.Canceled", err) - } - if generator.calls != 0 { - t.Fatalf("Generate() calls = %d, want 0", generator.calls) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generator := &stubTextGenerator{response: `[]`} + extractor := NewLLMExtractor(generator) + entries, err := extractor.Extract(context.Background(), tt.messages) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + if len(entries) != 0 { + t.Fatalf("len(entries) = %d, want 0", len(entries)) + } + if generator.calls != 0 { + t.Fatalf("Generate() calls = %d, want 0", generator.calls) + } + }) } } -// TestLLMExtractorExtractUsesFullRunMessages 验证提取器使用完整 run 消息而非固定 recent window。 +// TestLLMExtractorExtractUsesFullRunMessages 验证提取器使用完整 run 消息并排除 system/tool 噪声。 func TestLLMExtractorExtractUsesFullRunMessages(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator, 10) + extractor := NewLLMExtractor(generator) - messages := make([]providertypes.Message, 0, 16) - for index := 0; index < 12; index++ { - messages = append(messages, providertypes.Message{ - Role: providertypes.RoleUser, - Parts: []providertypes.ContentPart{providertypes.NewTextPart("user-" + string(rune('a'+index)))}, - }) - if index%3 == 0 { - messages = append(messages, providertypes.Message{ - Role: providertypes.RoleTool, - Parts: []providertypes.ContentPart{providertypes.NewTextPart("tool-" + string(rune('a'+index)))}, - }) - } + messages := []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}}, + {Role: providertypes.RoleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart("must call todo_write")}}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call_1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call_1", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("README body")}, + ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, + }, + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("last")}}, } _, err := extractor.Extract(context.Background(), messages) if err != nil { t.Fatalf("Extract() error = %v", err) } - if len(generator.messages) != 12 { - t.Fatalf("len(generator.messages) = %d, want 12", len(generator.messages)) + if len(generator.messages) != 4 { + t.Fatalf("len(generator.messages) = %d, want 4", len(generator.messages)) } for _, message := range generator.messages { - if message.Role == providertypes.RoleTool { - t.Fatalf("unexpected tool message in extraction context: %#v", message) + if message.Role == providertypes.RoleSystem { + t.Fatalf("system reminder should not enter extraction window: %#v", message) } } - if renderMemoParts(generator.messages[0].Parts) != "user-a" || - renderMemoParts(generator.messages[11].Parts) != "user-l" { - t.Fatalf("unexpected run window: first=%q last=%q", - renderMemoParts(generator.messages[0].Parts), - renderMemoParts(generator.messages[11].Parts)) - } -} - -func TestLLMExtractorExtractDecisionsIncludesExistingCandidates(t *testing.T) { - generator := &stubTextGenerator{ - response: `[{"action":"skip","ref":"user:u.md"},{"action":"update","ref":"project:p.md","title":"测试策略","content":"用户要求修改后先跑相关测试。","keywords":["test"]}]`, - } - extractor := NewLLMExtractor(generator, 10) - - decisions, err := extractor.ExtractDecisions( - context.Background(), - []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑相关测试。")}}, - }, - []ExtractionCandidate{ - { - Ref: "project:p.md", - Scope: ScopeProject, - Type: TypeFeedback, - Source: SourceAutoExtract, - Title: "测试策略", - Content: "用户要求修改后先跑测试。", - }, - }, - ) - if err != nil { - t.Fatalf("ExtractDecisions() error = %v", err) - } - if len(decisions) != 2 { - t.Fatalf("len(decisions) = %d, want 2", len(decisions)) - } - if decisions[0].Action != ExtractionActionSkip || decisions[0].Ref != "user:u.md" { - t.Fatalf("unexpected skip decision: %+v", decisions[0]) - } - if decisions[1].Action != ExtractionActionUpdate || decisions[1].Ref != "project:p.md" { - t.Fatalf("unexpected update decision: %+v", decisions[1]) - } - if decisions[1].Entry.Type != "" { - t.Fatalf("update decision should not require type, got %+v", decisions[1].Entry) - } - if !strings.Contains(generator.prompt, `"ref":"project:p.md"`) { - t.Fatalf("prompt should include existing memory candidates, got %q", generator.prompt) - } - if !strings.Contains(generator.prompt, `action="update"`) || !strings.Contains(generator.prompt, `source="extractor_auto"`) { - t.Fatalf("prompt should describe semantic dedupe protocol, got %q", generator.prompt) + if renderMemoParts(generator.messages[0].Parts) != "first" || renderMemoParts(generator.messages[3].Parts) != "last" { + t.Fatalf("unexpected extraction window: %+v", generator.messages) } } +// TestLLMExtractorExtractDropsIncompleteToolCallSpan 验证不完整的 tool call span 会被剔除。 func TestLLMExtractorExtractDropsIncompleteToolCallSpan(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator, 10) + extractor := NewLLMExtractor(generator) _, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}}, @@ -307,169 +188,134 @@ func TestLLMExtractorExtractDropsIncompleteToolCallSpan(t *testing.T) { } } -func TestLLMExtractorExtractKeepsProjectedToolCallSpan(t *testing.T) { - generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator, 10) +// TestLLMExtractorResolveDecisionUsesShortlist 验证决策阶段会携带 shortlist 并解析 update 结果。 +func TestLLMExtractorResolveDecisionUsesShortlist(t *testing.T) { + generator := &stubTextGenerator{ + response: `{"action":"update","ref":"project:p.md","title":"测试策略","content":"用户要求修改后先跑相关测试。","keywords":["test"]}`, + } + extractor := NewLLMExtractor(generator) - _, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("remember this")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call_1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + decision, err := extractor.ResolveDecision( + context.Background(), + Entry{Type: TypeFeedback, Title: "测试策略", Content: "以后修改完先跑相关测试。"}, + []ExtractionCandidate{ + { + Ref: "project:p.md", + Scope: ScopeProject, + Type: TypeFeedback, + Source: SourceAutoExtract, + Title: "测试策略", + Keywords: []string{"test"}, + Content: "用户要求修改后先跑测试。", }, }, - { - Role: providertypes.RoleTool, - ToolCallID: "call_1", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("README body")}, - ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, - }, - }) + ) if err != nil { - t.Fatalf("Extract() error = %v", err) + t.Fatalf("ResolveDecision() error = %v", err) } - if len(generator.messages) != 3 { - t.Fatalf("len(generator.messages) = %d, want 3", len(generator.messages)) + if decision.Action != ExtractionActionUpdate || decision.Ref != "project:p.md" { + t.Fatalf("unexpected decision: %+v", decision) } - if generator.messages[1].Role != providertypes.RoleAssistant || len(generator.messages[1].ToolCalls) != 1 { - t.Fatalf("expected assistant tool call span to be preserved, got %#v", generator.messages[1]) + if decision.Entry.Type != "" { + t.Fatalf("update decision should not require type, got %+v", decision.Entry) } - toolMessage := generator.messages[2] - if toolMessage.Role != providertypes.RoleTool { - t.Fatalf("expected projected tool message, got %#v", toolMessage) + if len(generator.messages) != 1 || generator.messages[0].Role != providertypes.RoleUser { + t.Fatalf("expected one synthetic user message, got %+v", generator.messages) } - toolText := renderMemoParts(toolMessage.Parts) - if !strings.Contains(toolText, "tool result") || !strings.Contains(toolText, "tool: filesystem_read_file") { - t.Fatalf("expected projected tool text, got %q", toolText) + if !strings.Contains(generator.prompt, `"ref":"project:p.md"`) { + t.Fatalf("prompt should include shortlist candidates, got %q", generator.prompt) } - if toolMessage.ToolMetadata != nil { - t.Fatalf("expected projected tool metadata to be cleared, got %#v", toolMessage.ToolMetadata) + if !strings.Contains(generator.prompt, `source="extractor_auto"`) { + t.Fatalf("prompt should describe auto-extracted update rule, got %q", generator.prompt) } } -func TestLLMExtractorExtractKeepsMetadataOnlyToolCallSpan(t *testing.T) { - generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator, 10) +// TestLLMExtractorResolveDecisionSupportsCreateAndSkip 验证单条决策支持 create 与 skip。 +func TestLLMExtractorResolveDecisionSupportsCreateAndSkip(t *testing.T) { + t.Run("create", func(t *testing.T) { + extractor := NewLLMExtractor(&stubTextGenerator{ + response: `{"action":"create","type":"project","title":"上线计划","content":"项目将在 2026-05-12 上线。","keywords":["release"]}`, + }) + decision, err := extractor.ResolveDecision( + context.Background(), + Entry{Type: TypeProject, Title: "上线计划", Content: "项目将在下周上线。"}, + []ExtractionCandidate{{Ref: "project:old.md", Scope: ScopeProject, Type: TypeProject, Source: SourceUserManual, Title: "旧计划", Content: "历史计划"}}, + ) + if err != nil { + t.Fatalf("ResolveDecision(create) error = %v", err) + } + if decision.Action != ExtractionActionCreate || decision.Entry.Type != TypeProject { + t.Fatalf("unexpected create decision: %+v", decision) + } + }) - _, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("remember this")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call_1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, - }, - }, - { - Role: providertypes.RoleTool, - ToolCallID: "call_1", - ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, - }, + t.Run("skip", func(t *testing.T) { + extractor := NewLLMExtractor(&stubTextGenerator{ + response: `{"action":"skip","ref":"user:u.md"}`, + }) + decision, err := extractor.ResolveDecision( + context.Background(), + Entry{Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"}, + []ExtractionCandidate{{Ref: "user:u.md", Scope: ScopeUser, Type: TypeUser, Source: SourceUserManual, Title: "中文回复", Content: "用户偏好中文回复。"}}, + ) + if err != nil { + t.Fatalf("ResolveDecision(skip) error = %v", err) + } + if decision.Action != ExtractionActionSkip || decision.Ref != "user:u.md" { + t.Fatalf("unexpected skip decision: %+v", decision) + } }) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if len(generator.messages) != 3 { - t.Fatalf("len(generator.messages) = %d, want 3", len(generator.messages)) - } - toolMessage := generator.messages[2] - if toolMessage.Role != providertypes.RoleTool { - t.Fatalf("expected projected tool message, got %#v", toolMessage) - } - toolText := renderMemoParts(toolMessage.Parts) - if !strings.Contains(toolText, "tool result") || - !strings.Contains(toolText, "tool: filesystem_read_file") || - !strings.Contains(toolText, "meta.path: README.md") { - t.Fatalf("expected metadata-only tool text, got %q", toolText) - } - if strings.Contains(toolText, "content:\n") { - t.Fatalf("expected metadata-only projection to omit content section, got %q", toolText) - } - if toolMessage.ToolMetadata != nil { - t.Fatalf("expected projected tool metadata to be cleared, got %#v", toolMessage.ToolMetadata) - } } -func TestLLMExtractorExtractSkipsOrphanAndClearedToolMessages(t *testing.T) { - generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator, 10) +// TestLLMExtractorErrors 验证取消、空生成器和上游错误会正确透传。 +func TestLLMExtractorErrors(t *testing.T) { + t.Run("canceled context", func(t *testing.T) { + generator := &stubTextGenerator{response: `[]`} + extractor := NewLLMExtractor(generator) + ctx, cancel := context.WithCancel(context.Background()) + cancel() - _, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("alpha")}}, - {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call_1", Name: "bash", Arguments: `{}`}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call_1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("[Old tool result content cleared]")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("beta")}}, - }) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if len(generator.messages) != 2 { - t.Fatalf("len(generator.messages) = %d, want 2", len(generator.messages)) - } - for _, message := range generator.messages { - if message.Role == providertypes.RoleTool || len(message.ToolCalls) > 0 { - t.Fatalf("unexpected tool-related message in extraction window: %#v", message) + _, err := extractor.Extract(ctx, []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("Extract() error = %v, want context.Canceled", err) } - } -} - -func TestLLMExtractorExtractNilGenerator(t *testing.T) { - var extractor *LLMExtractor - _, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, }) - if err == nil || !strings.Contains(err.Error(), "text generator is nil") { - t.Fatalf("Extract() error = %v", err) - } - extractor = NewLLMExtractor(nil, 10) - _, err = extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, + t.Run("nil generator", func(t *testing.T) { + extractor := NewLLMExtractor(nil) + _, err := extractor.Extract(context.Background(), []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, + }) + if err == nil || !strings.Contains(err.Error(), "text generator is nil") { + t.Fatalf("Extract() error = %v", err) + } }) - if err == nil || !strings.Contains(err.Error(), "text generator is nil") { - t.Fatalf("Extract() error = %v", err) - } -} -func TestLLMExtractorExtractGeneratorFailure(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{err: errors.New("upstream failed")}, 10) - _, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, + t.Run("generator failure", func(t *testing.T) { + extractor := NewLLMExtractor(&stubTextGenerator{err: errors.New("upstream failed")}) + _, err := extractor.Extract(context.Background(), []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, + }) + if err == nil || !strings.Contains(err.Error(), "upstream failed") { + t.Fatalf("Extract() error = %v", err) + } }) - if err == nil || !strings.Contains(err.Error(), "upstream failed") { - t.Fatalf("Extract() error = %v", err) - } } -func TestExtractJSONArrayErrors(t *testing.T) { +// TestJSONPayloadExtractors 验证数组与对象提取器的错误分支。 +func TestJSONPayloadExtractors(t *testing.T) { if _, err := extractJSONArray("no json here"); err == nil || !strings.Contains(err.Error(), "does not contain") { t.Fatalf("expected missing array error, got %v", err) } if _, err := extractJSONArray(`[{"a":"x"}`); err == nil || !strings.Contains(err.Error(), "incomplete") { t.Fatalf("expected incomplete array error, got %v", err) } -} - -func TestLLMExtractorExtractImageOnlyUserMessageSkipsGenerator(t *testing.T) { - generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator, 10) - - entries, err := extractor.Extract(context.Background(), []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewRemoteImagePart("https://example.com/pic.png")}}, - }) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if len(entries) != 0 { - t.Fatalf("len(entries) = %d, want 0", len(entries)) + if _, err := extractJSONObject("no json here"); err == nil || !strings.Contains(err.Error(), "does not contain") { + t.Fatalf("expected missing object error, got %v", err) } - if generator.calls != 0 { - t.Fatalf("Generate() calls = %d, want 0", generator.calls) + if _, err := extractJSONObject(`{"a":"x"`); err == nil || !strings.Contains(err.Error(), "incomplete") { + t.Fatalf("expected incomplete object error, got %v", err) } } diff --git a/internal/memo/semantic_candidates.go b/internal/memo/semantic_candidates.go new file mode 100644 index 00000000..01f911ca --- /dev/null +++ b/internal/memo/semantic_candidates.go @@ -0,0 +1,321 @@ +package memo + +import ( + "context" + "fmt" + "sort" + "strings" + "unicode" +) + +const ( + semanticCandidateShortlistLimit = 5 + semanticCandidateContentMaxRunes = 240 +) + +type scoredExtractionCandidate struct { + candidate ExtractionCandidate + score int +} + +// semanticCandidateShortlist 为候选记忆检索相关 existing memo,并限制进入 LLM 的数量。 +func (s *Service) semanticCandidateShortlist( + ctx context.Context, + entry Entry, + limit int, +) ([]ExtractionCandidate, error) { + if s == nil { + return nil, nil + } + if limit <= 0 { + limit = semanticCandidateShortlistLimit + } + if err := s.ensureSemanticCandidateIndex(ctx); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + scored := make([]scoredExtractionCandidate, 0, len(s.semanticCandidatesByRef)) + for _, candidate := range s.semanticCandidatesByRef { + score := scoreExtractionCandidate(entry, candidate) + if score <= 0 { + continue + } + scored = append(scored, scoredExtractionCandidate{ + candidate: cloneExtractionCandidate(candidate), + score: score, + }) + } + sort.SliceStable(scored, func(i, j int) bool { + if scored[i].score != scored[j].score { + return scored[i].score > scored[j].score + } + return scored[i].candidate.Ref < scored[j].candidate.Ref + }) + if len(scored) > limit { + scored = scored[:limit] + } + + result := make([]ExtractionCandidate, 0, len(scored)) + for _, item := range scored { + result = append(result, item.candidate) + } + return result, nil +} + +// ensureSemanticCandidateIndex 懒加载语义去重候选索引,避免每轮 accepted run 都重扫 topic 文件。 +func (s *Service) ensureSemanticCandidateIndex(ctx context.Context) error { + s.mu.Lock() + if s.semanticIndexReady { + s.mu.Unlock() + return nil + } + s.mu.Unlock() + + s.semanticIndexMu.Lock() + defer s.semanticIndexMu.Unlock() + + s.mu.Lock() + if s.semanticIndexReady { + s.mu.Unlock() + return nil + } + s.mu.Unlock() + + candidatesByRef := make(map[string]ExtractionCandidate) + for _, scope := range supportedStorageScopes() { + index, err := s.store.LoadIndex(ctx, scope) + if err != nil { + return fmt.Errorf("memo: load index: %w", err) + } + for _, entry := range index.Entries { + topicFile := strings.TrimSpace(entry.TopicFile) + if topicFile == "" { + continue + } + topicContent, err := s.store.LoadTopic(ctx, scope, topicFile) + if err != nil { + continue + } + candidate := buildExtractionCandidateFromTopic(scope, entry, topicContent) + if candidate.Ref == "" { + continue + } + candidatesByRef[candidate.Ref] = candidate + } + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.semanticIndexReady { + return nil + } + s.semanticCandidatesByRef = candidatesByRef + s.semanticIndexReady = true + return nil +} + +// trackSemanticCandidateLocked 在语义索引已就绪时同步维护单条 memo 的 shortlist 快照。 +func (s *Service) trackSemanticCandidateLocked(scope Scope, entry Entry) { + if !s.semanticIndexReady { + return + } + topicFile := strings.TrimSpace(entry.TopicFile) + if topicFile == "" { + return + } + candidate := buildExtractionCandidateFromEntry(scope, entry) + if candidate.Ref == "" { + return + } + if s.semanticCandidatesByRef == nil { + s.semanticCandidatesByRef = make(map[string]ExtractionCandidate) + } + s.semanticCandidatesByRef[candidate.Ref] = candidate +} + +// removeSemanticCandidateLocked 从语义索引中删除指定 topic 的候选快照。 +func (s *Service) removeSemanticCandidateLocked(scope Scope, topicFile string) { + if !s.semanticIndexReady { + return + } + topicFile = strings.TrimSpace(topicFile) + if topicFile == "" { + return + } + delete(s.semanticCandidatesByRef, scopedTopicKey(scope, topicFile)) +} + +// buildExtractionCandidateFromEntry 将内存中的完整 Entry 收敛为 shortlist 快照。 +func buildExtractionCandidateFromEntry(scope Scope, entry Entry) ExtractionCandidate { + topicFile := strings.TrimSpace(entry.TopicFile) + if topicFile == "" { + return ExtractionCandidate{} + } + return ExtractionCandidate{ + Ref: scopedTopicKey(scope, topicFile), + Scope: scope, + Type: entry.Type, + Source: strings.TrimSpace(entry.Source), + Title: NormalizeTitle(entry.Title), + Keywords: normalizeKeywords(entry.Keywords), + Content: truncateSemanticContent(entry.Content), + } +} + +// buildExtractionCandidateFromTopic 将 topic 文件内容解析为 shortlist 快照。 +func buildExtractionCandidateFromTopic(scope Scope, entry Entry, topicContent string) ExtractionCandidate { + source, keywords, content := parseTopicSnapshot(topicContent) + entry.Source = source + entry.Keywords = keywords + entry.Content = content + return buildExtractionCandidateFromEntry(scope, entry) +} + +// parseTopicSnapshot 解析 topic frontmatter 中的 source、keywords 与正文内容。 +func parseTopicSnapshot(topic string) (string, []string, string) { + parts := strings.Split(topic, "---") + if len(parts) < 3 { + return "", nil, strings.TrimSpace(topic) + } + + var ( + source string + keywords []string + ) + frontmatter := parts[1] + body := strings.TrimSpace(strings.Join(parts[2:], "---")) + for _, line := range strings.Split(frontmatter, "\n") { + line = strings.TrimSpace(line) + switch { + case strings.HasPrefix(line, "source:"): + source = strings.TrimSpace(strings.TrimPrefix(line, "source:")) + case strings.HasPrefix(line, "keywords:"): + raw := strings.TrimSpace(strings.TrimPrefix(line, "keywords:")) + keywords = parseTopicKeywords(raw) + } + } + return source, keywords, body +} + +// parseTopicKeywords 解析 frontmatter 中的单行 keywords 列表。 +func parseTopicKeywords(raw string) []string { + raw = strings.TrimSpace(raw) + raw = strings.TrimPrefix(raw, "[") + raw = strings.TrimSuffix(raw, "]") + if raw == "" { + return nil + } + + parts := strings.Split(raw, ",") + keywords := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(strings.Trim(part, `"'`)) + if part == "" { + continue + } + keywords = append(keywords, part) + } + return normalizeKeywords(keywords) +} + +// truncateSemanticContent 对候选正文做固定上限截断,避免把全量 memo 内容送入 prompt。 +func truncateSemanticContent(content string) string { + content = strings.TrimSpace(content) + if content == "" { + return "" + } + runes := []rune(content) + if len(runes) <= semanticCandidateContentMaxRunes { + return content + } + return string(runes[:semanticCandidateContentMaxRunes]) + "..." +} + +// scoreExtractionCandidate 为 shortlist 候选打分,优先保留最相关的 memo。 +func scoreExtractionCandidate(target Entry, existing ExtractionCandidate) int { + targetTitle := strings.ToLower(NormalizeTitle(target.Title)) + targetContent := strings.ToLower(strings.TrimSpace(target.Content)) + existingTitle := strings.ToLower(NormalizeTitle(existing.Title)) + existingContent := strings.ToLower(strings.TrimSpace(existing.Content)) + + score := 0 + if target.Type == existing.Type { + score += 80 + } + if targetTitle != "" && targetTitle == existingTitle { + score += 1000 + } + if targetContent != "" && targetContent == existingContent { + score += 900 + } + if targetTitle != "" && strings.Contains(existingContent, targetTitle) { + score += 120 + } + if existingTitle != "" && strings.Contains(targetContent, existingTitle) { + score += 120 + } + + targetKeywordSet := tokenSet(append([]string{target.Title, target.Content}, target.Keywords...)...) + existingKeywordSet := tokenSet(append([]string{existing.Title, existing.Content}, existing.Keywords...)...) + for token := range targetKeywordSet { + if _, ok := existingKeywordSet[token]; ok { + score += 10 + } + } + for _, keyword := range normalizeKeywords(target.Keywords) { + normalized := strings.ToLower(strings.TrimSpace(keyword)) + if normalized == "" { + continue + } + for _, existingKeyword := range existing.Keywords { + if normalized == strings.ToLower(strings.TrimSpace(existingKeyword)) { + score += 40 + break + } + } + } + return score +} + +// tokenSet 将多段文本规整为去重后的 token 集合,供 shortlist 相关性排序使用。 +func tokenSet(parts ...string) map[string]struct{} { + set := make(map[string]struct{}) + for _, part := range parts { + for _, token := range tokenizeSemanticText(part) { + set[token] = struct{}{} + } + } + return set +} + +// tokenizeSemanticText 按字母数字边界切分文本,生成 shortlist 排序所需的归一化 token。 +func tokenizeSemanticText(text string) []string { + text = strings.TrimSpace(strings.ToLower(text)) + if text == "" { + return nil + } + parts := strings.FieldsFunc(text, func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsNumber(r) + }) + tokens := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + tokens = append(tokens, part) + } + return tokens +} + +// cloneExtractionCandidate 深拷贝 shortlist 候选,避免切片字段共享底层数组。 +func cloneExtractionCandidate(candidate ExtractionCandidate) ExtractionCandidate { + cloned := candidate + if len(candidate.Keywords) > 0 { + cloned.Keywords = append([]string(nil), candidate.Keywords...) + } + return cloned +} diff --git a/internal/memo/service.go b/internal/memo/service.go index 42551f26..2e30e0a3 100644 --- a/internal/memo/service.go +++ b/internal/memo/service.go @@ -14,14 +14,17 @@ import ( // Service 编排记忆的存储、检索、删除和索引维护,是 memo 子系统对外的统一入口。 type Service struct { - store Store - config config.MemoConfig - mu sync.Mutex - sourceInvl func() - autoExtractIndexMu sync.Mutex - autoExtractIndexReady bool - autoExtractKeysByTopic map[string]string - autoExtractKeyRefs map[string]int + store Store + config config.MemoConfig + mu sync.Mutex + sourceInvl func() + autoExtractIndexMu sync.Mutex + autoExtractIndexReady bool + autoExtractKeysByTopic map[string]string + autoExtractKeyRefs map[string]int + semanticIndexMu sync.Mutex + semanticIndexReady bool + semanticCandidatesByRef map[string]ExtractionCandidate } // NewService 创建 memo Service 实例。 @@ -187,6 +190,15 @@ func (s *Service) updateAutoExtractIfAllowed(ctx context.Context, ref string, ne s.trackAutoExtractEntryLocked(scope, current) } } + if s.semanticIndexReady { + s.removeSemanticCandidateLocked(scope, topicFile) + for _, removed := range removedEntries { + s.removeSemanticCandidateLocked(scope, removed.TopicFile) + } + if indexContainsTopicFile(working, topicFile) { + s.trackSemanticCandidateLocked(scope, current) + } + } s.invalidateCache() return true, nil @@ -239,6 +251,7 @@ func (s *Service) Remove(ctx context.Context, keyword string, scope Scope) (int, if topicFile := strings.TrimSpace(entry.TopicFile); topicFile != "" { _ = s.store.DeleteTopic(ctx, bucket, topicFile) s.removeAutoExtractTopicLocked(bucket, topicFile) + s.removeSemanticCandidateLocked(bucket, topicFile) } } removed += len(removedEntries) @@ -393,6 +406,17 @@ func (s *Service) saveEntryLocked(ctx context.Context, entry Entry) error { s.trackAutoExtractEntryLocked(scope, entry) } } + if s.semanticIndexReady { + if replaced { + s.removeSemanticCandidateLocked(scope, previous.TopicFile) + } + for _, removed := range removedEntries { + s.removeSemanticCandidateLocked(scope, removed.TopicFile) + } + if indexContainsEntryID(working, entry.ID) { + s.trackSemanticCandidateLocked(scope, entry) + } + } s.invalidateCache() return nil diff --git a/internal/memo/service_test.go b/internal/memo/service_test.go index 988b71ff..462d17f3 100644 --- a/internal/memo/service_test.go +++ b/internal/memo/service_test.go @@ -3,6 +3,7 @@ package memo import ( "context" "errors" + "fmt" "strings" "testing" @@ -11,10 +12,9 @@ import ( func testMemoConfig() config.MemoConfig { return config.MemoConfig{ - MaxEntries: 200, - MaxIndexBytes: 16 * 1024, - ExtractTimeoutSec: 15, - ExtractRecentMessages: 10, + MaxEntries: 200, + MaxIndexBytes: 16 * 1024, + ExtractTimeoutSec: 15, } } @@ -300,10 +300,9 @@ func TestServiceAutoExtractDedupAcrossScopes(t *testing.T) { func TestServiceAutoExtractTrimmedEntryDoesNotPolluteDedupIndex(t *testing.T) { svc := NewService(newMemoryTestStore(), config.MemoConfig{ - MaxEntries: 10, - MaxIndexBytes: 1, - ExtractTimeoutSec: 15, - ExtractRecentMessages: 10, + MaxEntries: 10, + MaxIndexBytes: 1, + ExtractTimeoutSec: 15, }, nil) entry := Entry{ Type: TypeUser, @@ -375,6 +374,82 @@ func TestMatchesKeywordIncludesContent(t *testing.T) { } } +func TestServiceSemanticCandidateShortlistMatchesRelevantMemo(t *testing.T) { + store := NewFileStore(t.TempDir(), t.TempDir()) + svc := NewService(store, testMemoConfig(), nil) + + if err := svc.Add(context.Background(), Entry{ + Type: TypeFeedback, + Title: "测试策略", + Content: "用户要求修改后先跑测试。", + Source: SourceAutoExtract, + }); err != nil { + t.Fatalf("seed relevant Add() error = %v", err) + } + if err := svc.Add(context.Background(), Entry{ + Type: TypeUser, + Title: "语言偏好", + Content: "用户偏好中文回复。", + Source: SourceUserManual, + }); err != nil { + t.Fatalf("seed unrelated Add() error = %v", err) + } + + shortlist, err := svc.semanticCandidateShortlist(context.Background(), Entry{ + Type: TypeFeedback, + Title: "测试策略", + Content: "用户要求修改后先跑相关测试。", + }, semanticCandidateShortlistLimit) + if err != nil { + t.Fatalf("semanticCandidateShortlist() error = %v", err) + } + if len(shortlist) != 1 { + t.Fatalf("len(shortlist) = %d, want 1", len(shortlist)) + } + if shortlist[0].Title != "测试策略" || shortlist[0].Source != SourceAutoExtract { + t.Fatalf("unexpected shortlist candidate: %+v", shortlist[0]) + } +} + +func TestServiceSemanticCandidateShortlistLimitsAndTruncates(t *testing.T) { + store := NewFileStore(t.TempDir(), t.TempDir()) + svc := NewService(store, testMemoConfig(), nil) + + longContent := strings.Repeat("release-note ", 40) + for index := 0; index < 6; index++ { + if err := svc.Add(context.Background(), Entry{ + Type: TypeProject, + Title: fmt.Sprintf("发布计划 %d", index), + Content: longContent + fmt.Sprintf(" milestone-%d", index), + Keywords: []string{"release"}, + Source: SourceAutoExtract, + }); err != nil { + t.Fatalf("seed Add(%d) error = %v", index, err) + } + } + + shortlist, err := svc.semanticCandidateShortlist(context.Background(), Entry{ + Type: TypeProject, + Title: "发布计划", + Content: "release milestone", + Keywords: []string{"release"}, + }, semanticCandidateShortlistLimit) + if err != nil { + t.Fatalf("semanticCandidateShortlist() error = %v", err) + } + if len(shortlist) != semanticCandidateShortlistLimit { + t.Fatalf("len(shortlist) = %d, want %d", len(shortlist), semanticCandidateShortlistLimit) + } + for _, candidate := range shortlist { + if len([]rune(candidate.Content)) > semanticCandidateContentMaxRunes+3 { + t.Fatalf("candidate content should be truncated, got %q", candidate.Content) + } + if len(candidate.Keywords) != 1 || candidate.Keywords[0] != "release" { + t.Fatalf("candidate keywords should be preserved, got %+v", candidate) + } + } +} + func TestTrimIndexEntriesByBytesRemovesMinimalPrefix(t *testing.T) { index := &Index{ Entries: []Entry{ diff --git a/internal/memo/types.go b/internal/memo/types.go index 0d1a62a0..04d2d084 100644 --- a/internal/memo/types.go +++ b/internal/memo/types.go @@ -103,12 +103,13 @@ const ( // ExtractionCandidate 表示提供给模型做语义去重的既有记忆快照。 type ExtractionCandidate struct { - Ref string `json:"ref"` - Scope Scope `json:"scope"` - Type Type `json:"type"` - Source string `json:"source"` - Title string `json:"title"` - Content string `json:"content"` + Ref string `json:"ref"` + Scope Scope `json:"scope"` + Type Type `json:"type"` + Source string `json:"source"` + Title string `json:"title"` + Keywords []string `json:"keywords,omitempty"` + Content string `json:"content"` } // ExtractionDecision 表示模型针对新旧记忆关系返回的结构化决策。 @@ -133,13 +134,13 @@ type Extractor interface { Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) } -// DecisionExtractor 定义带既有记忆快照的语义提取能力。 -type DecisionExtractor interface { - ExtractDecisions( +// DecisionResolver 定义针对单条候选记忆做去重决策的最小能力。 +type DecisionResolver interface { + ResolveDecision( ctx context.Context, - messages []providertypes.Message, + candidate Entry, existing []ExtractionCandidate, - ) ([]ExtractionDecision, error) + ) (ExtractionDecision, error) } // TextGenerator 定义调用 LLM 生成文本的最小能力,用于记忆提取。 diff --git a/internal/tools/memo/remember_test.go b/internal/tools/memo/remember_test.go index ac0d39f8..bf4defef 100644 --- a/internal/tools/memo/remember_test.go +++ b/internal/tools/memo/remember_test.go @@ -15,10 +15,9 @@ func newTestService(t *testing.T) *memo.Service { t.Helper() store := memo.NewFileStore(t.TempDir(), t.TempDir()) return memo.NewService(store, config.MemoConfig{ - MaxEntries: 200, - MaxIndexBytes: 16 * 1024, - ExtractTimeoutSec: 15, - ExtractRecentMessages: 10, + MaxEntries: 200, + MaxIndexBytes: 16 * 1024, + ExtractTimeoutSec: 15, }, nil) }