From 36fd1d58c79906a2f0b28021d149a54427cc1f60 Mon Sep 17 00:00:00 2001
From: Yumiue <229866007@qq.com>
Date: Sun, 10 May 2026 10:36:56 +0800
Subject: [PATCH 1/5] =?UTF-8?q?fix(memo):memo=E5=88=9D=E6=AD=A5=E4=BF=AE?=
=?UTF-8?q?=E7=90=86,run=E6=88=AA=E5=8F=96=EF=BC=8Cllm=E6=8F=90=E5=8F=96?=
=?UTF-8?q?=E6=97=B6=E5=8E=BB=E9=87=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
internal/app/bootstrap_test.go | 42 +++++++
internal/context/projection.go | 36 ++++++
internal/context/projection_test.go | 44 ++++++++
internal/memo/auto_extractor.go | 64 +++++++++++
internal/memo/auto_extractor_test.go | 144 ++++++++++++++++++++++++
internal/memo/llm_extractor.go | 123 ++++++++++++++++++---
internal/memo/llm_extractor_test.go | 61 +++++++++--
internal/memo/service.go | 158 +++++++++++++++++++++++++++
internal/memo/types.go | 38 +++++++
internal/runtime/memo.go | 24 ++++
internal/runtime/run.go | 2 +-
internal/runtime/runtime.go | 2 +-
internal/runtime/runtime_test.go | 58 ++++++++++
internal/runtime/session_mutation.go | 23 +++-
internal/runtime/state.go | 2 +
15 files changed, 793 insertions(+), 28 deletions(-)
diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go
index 3760dded..917823cb 100644
--- a/internal/app/bootstrap_test.go
+++ b/internal/app/bootstrap_test.go
@@ -1675,6 +1675,48 @@ func TestNewMemoExtractorAdapterBuildsProviderSafeMemoWindow(t *testing.T) {
}
}
+func TestNewMemoExtractorAdapterUsesFullRunMemoWindow(t *testing.T) {
+ t.Setenv(config.OpenAIDefaultAPIKeyEnv, "token")
+ cfg := config.StaticDefaults().Clone()
+ cfg.SelectedProvider = config.OpenAIName
+ cfg.Memo.ExtractRecentMessages = 3
+ manager := config.NewManager(config.NewLoader("", &cfg))
+
+ providerStub := &stubMemoProvider{
+ generate: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error {
+ if len(req.Messages) != 12 {
+ t.Fatalf("unexpected memo window length %d, want full run: %+v", len(req.Messages), req.Messages)
+ }
+ events <- providertypes.NewTextDeltaStreamEvent(`[]`)
+ events <- providertypes.NewMessageDoneStreamEvent("stop", nil)
+ return nil
+ },
+ }
+ factory := &stubMemoProviderFactory{provider: providerStub}
+ scheduler := &stubMemoExtractorScheduler{}
+ extractor := newMemoExtractorAdapter(factory, manager, scheduler)
+
+ inputMessages := make([]providertypes.Message, 0, 12)
+ for index := 0; index < 12; index++ {
+ inputMessages = append(inputMessages, providertypes.Message{
+ Role: providertypes.RoleUser,
+ Parts: []providertypes.ContentPart{providertypes.NewTextPart(fmt.Sprintf("message-%02d", index))},
+ })
+ }
+ extractor.Schedule("session-1", inputMessages)
+ if !scheduler.called || scheduler.extractor == nil {
+ t.Fatalf("expected scheduler to receive extractor")
+ }
+
+ _, err := scheduler.extractor.Extract(context.Background(), inputMessages)
+ if err != nil {
+ t.Fatalf("extractor.Extract() error = %v", err)
+ }
+ if !factory.called {
+ t.Fatalf("expected provider factory Build to be called")
+ }
+}
+
func TestNewMemoExtractorAdapterKeepsScheduledConfigSnapshot(t *testing.T) {
t.Setenv(config.OpenAIDefaultAPIKeyEnv, "openai-token")
t.Setenv(config.QiniuDefaultAPIKeyEnv, "qiniu-token")
diff --git a/internal/context/projection.go b/internal/context/projection.go
index ac7da4b3..298f7e8c 100644
--- a/internal/context/projection.go
+++ b/internal/context/projection.go
@@ -79,6 +79,42 @@ func BuildRecentMessagesForModel(messages []providertypes.Message, limit int) []
return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected)))
}
+// BuildMemoExtractionMessagesForModel 构造完整 run 的 provider-safe 记忆提取上下文。
+func BuildMemoExtractionMessagesForModel(messages []providertypes.Message) []providertypes.Message {
+ if len(messages) == 0 {
+ return nil
+ }
+
+ keep := make([]bool, len(messages))
+ for index := 0; index < len(messages); index++ {
+ message := messages[index]
+ if message.Role == providertypes.RoleTool {
+ continue
+ }
+
+ if message.Role == providertypes.RoleAssistant && len(message.ToolCalls) > 0 {
+ for _, spanIndex := range matchedToolCallSpan(messages, index) {
+ keep[spanIndex] = true
+ }
+ continue
+ }
+
+ keep[index] = true
+ }
+
+ selected := make([]providertypes.Message, 0, len(messages))
+ for index, message := range messages {
+ if keep[index] {
+ selected = append(selected, message)
+ }
+ }
+ if len(selected) == 0 {
+ return nil
+ }
+
+ return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected)))
+}
+
// matchedToolCallSpan 返回 assistant tool call 与其完整 tool 响应组成的合法窗口下标集合。
func matchedToolCallSpan(messages []providertypes.Message, assistantIndex int) []int {
if assistantIndex < 0 || assistantIndex >= len(messages) {
diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go
index 372e3c3f..3b0c932a 100644
--- a/internal/context/projection_test.go
+++ b/internal/context/projection_test.go
@@ -140,6 +140,50 @@ func TestBuildRecentMessagesForModelKeepsOnlyRecentValidAnchors(t *testing.T) {
}
}
+func TestBuildMemoExtractionMessagesForModelKeepsFullRunSafeSpans(t *testing.T) {
+ t.Parallel()
+
+ messages := []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}},
+ {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan")}},
+ {
+ Role: providertypes.RoleAssistant,
+ ToolCalls: []providertypes.ToolCall{
+ {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`},
+ },
+ },
+ {
+ Role: providertypes.RoleTool,
+ ToolCallID: "call-1",
+ Parts: []providertypes.ContentPart{providertypes.NewTextPart("README body")},
+ ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"},
+ },
+ {
+ Role: providertypes.RoleAssistant,
+ ToolCalls: []providertypes.ToolCall{
+ {ID: "call-missing", Name: "bash", Arguments: `{}`},
+ },
+ },
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("last")}},
+ }
+
+ projected := BuildMemoExtractionMessagesForModel(messages)
+ if len(projected) != 4 {
+ t.Fatalf("len(projected) = %d, want 4: %+v", len(projected), projected)
+ }
+ if renderDisplayParts(projected[0].Parts) != "first" || renderDisplayParts(projected[3].Parts) != "last" {
+ t.Fatalf("expected full run user messages to remain, got %+v", projected)
+ }
+ if projected[1].Role != providertypes.RoleAssistant || len(projected[1].ToolCalls) != 1 {
+ t.Fatalf("expected complete assistant tool span, got %+v", projected[1])
+ }
+ if projected[2].Role != providertypes.RoleTool ||
+ !strings.Contains(renderDisplayParts(projected[2].Parts), "tool result") ||
+ projected[2].ToolMetadata != nil {
+ t.Fatalf("expected projected tool result, got %+v", projected[2])
+ }
+}
+
func TestBuildRecentMessagesForModelRespectsAbsoluteMessageBudget(t *testing.T) {
t.Parallel()
diff --git a/internal/memo/auto_extractor.go b/internal/memo/auto_extractor.go
index aaadfec9..a1ccb43b 100644
--- a/internal/memo/auto_extractor.go
+++ b/internal/memo/auto_extractor.go
@@ -225,6 +225,24 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider
ctx, cancel := context.WithTimeout(context.Background(), a.extractTimeout)
defer cancel()
+ if decisionExtractor, ok := extractor.(DecisionExtractor); ok {
+ existing, err := a.svc.autoExtractionCandidates(ctx)
+ if err != nil {
+ a.logError("memo: auto extract load candidates failed: %v", err)
+ return false
+ }
+ decisions, err := decisionExtractor.ExtractDecisions(ctx, messages, existing)
+ if err != nil {
+ if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) {
+ a.logError("memo: auto extract skipped (protocol_mismatch): %v", err)
+ return true
+ }
+ a.logError("memo: auto extract failed: %v", err)
+ return false
+ }
+ return a.applyExtractionDecisions(ctx, decisions)
+ }
+
entries, err := extractor.Extract(ctx, messages)
if err != nil {
if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) {
@@ -265,6 +283,52 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider
return succeeded
}
+// applyExtractionDecisions 应用模型返回的 create/update/skip 决策,并保留本地精确去重兜底。
+func (a *AutoExtractor) applyExtractionDecisions(ctx context.Context, decisions []ExtractionDecision) bool {
+ if len(decisions) == 0 {
+ return true
+ }
+
+ seenCreates := make(map[string]struct{}, len(decisions))
+ succeeded := true
+ for _, decision := range decisions {
+ switch decision.Action {
+ case ExtractionActionCreate:
+ entry := decision.Entry
+ entry.Source = SourceAutoExtract
+ key := autoExtractDedupKey(entry)
+ if key == "" {
+ continue
+ }
+ if _, exists := seenCreates[key]; exists {
+ continue
+ }
+ added, err := a.svc.addAutoExtractIfAbsent(ctx, entry)
+ if err != nil {
+ a.logError("memo: auto extract add failed: %v", err)
+ succeeded = false
+ continue
+ }
+ seenCreates[key] = struct{}{}
+ if !added {
+ continue
+ }
+ case ExtractionActionUpdate:
+ entry := decision.Entry
+ entry.Source = SourceAutoExtract
+ _, err := a.svc.updateAutoExtractIfAllowed(ctx, decision.Ref, entry)
+ if err != nil {
+ a.logError("memo: auto extract update failed: %v", err)
+ succeeded = false
+ continue
+ }
+ case ExtractionActionSkip:
+ continue
+ }
+ }
+ return succeeded
+}
+
// autoExtractDedupKey 生成自动提取条目的精确去重键。
func autoExtractDedupKey(entry Entry) string {
title := NormalizeTitle(entry.Title)
diff --git a/internal/memo/auto_extractor_test.go b/internal/memo/auto_extractor_test.go
index 2f188d07..6c4e5835 100644
--- a/internal/memo/auto_extractor_test.go
+++ b/internal/memo/auto_extractor_test.go
@@ -38,6 +38,33 @@ func (s *stubMemoExtractor) Calls() int {
return s.callCount
}
+type stubDecisionMemoExtractor struct {
+ mu sync.Mutex
+ callCount int
+ candidates []ExtractionCandidate
+ decisions []ExtractionDecision
+ err error
+}
+
+func (s *stubDecisionMemoExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) {
+ return nil, nil
+}
+
+func (s *stubDecisionMemoExtractor) ExtractDecisions(
+ ctx context.Context,
+ messages []providertypes.Message,
+ existing []ExtractionCandidate,
+) ([]ExtractionDecision, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.callCount++
+ s.candidates = append([]ExtractionCandidate(nil), existing...)
+ if s.err != nil {
+ return nil, s.err
+ }
+ return append([]ExtractionDecision(nil), s.decisions...), nil
+}
+
func newAutoExtractorTestService(t *testing.T) *Service {
t.Helper()
store := NewFileStore(t.TempDir(), t.TempDir())
@@ -214,6 +241,123 @@ func TestAutoExtractorSuppressesExactDuplicates(t *testing.T) {
})
}
+func TestAutoExtractorAppliesSemanticUpdateOnlyForAutoExtractedMemory(t *testing.T) {
+ svc := newAutoExtractorTestService(t)
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeFeedback,
+ Title: "测试策略",
+ Content: "用户要求修改后先跑测试。",
+ Source: SourceAutoExtract,
+ }); err != nil {
+ t.Fatalf("seed auto Add() error = %v", err)
+ }
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeUser,
+ Title: "语言偏好",
+ Content: "用户手动保存偏好中文回复。",
+ Source: SourceUserManual,
+ }); err != nil {
+ t.Fatalf("seed manual Add() error = %v", err)
+ }
+
+ candidates, err := svc.autoExtractionCandidates(context.Background())
+ if err != nil {
+ t.Fatalf("autoExtractionCandidates() error = %v", err)
+ }
+ var autoRef, manualRef string
+ for _, candidate := range candidates {
+ switch candidate.Source {
+ case SourceAutoExtract:
+ autoRef = candidate.Ref
+ case SourceUserManual:
+ manualRef = candidate.Ref
+ }
+ }
+ if autoRef == "" || manualRef == "" {
+ t.Fatalf("expected refs, auto=%q manual=%q candidates=%+v", autoRef, manualRef, candidates)
+ }
+
+ extractor := &stubDecisionMemoExtractor{
+ decisions: []ExtractionDecision{
+ {
+ Action: ExtractionActionUpdate,
+ Ref: autoRef,
+ Entry: Entry{
+ Title: "测试策略",
+ Content: "用户要求修改后先跑相关测试。",
+ Keywords: []string{"test"},
+ },
+ },
+ {
+ Action: ExtractionActionUpdate,
+ Ref: manualRef,
+ Entry: Entry{
+ Title: "语言偏好",
+ Content: "不应覆盖手动记忆。",
+ },
+ },
+ },
+ }
+ auto := NewAutoExtractor(extractor, svc, time.Second)
+ auto.debounce = 5 * time.Millisecond
+ auto.logf = func(string, ...any) {}
+ registerAutoExtractorCleanup(t, auto)
+
+ auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("更新测试策略")}}})
+ waitFor(t, time.Second, func() bool {
+ extractor.mu.Lock()
+ defer extractor.mu.Unlock()
+ return extractor.callCount == 1
+ })
+
+ recall, err := svc.Recall(context.Background(), "测试策略", ScopeProject)
+ if err != nil {
+ t.Fatalf("Recall(auto) error = %v", err)
+ }
+ if len(recall) != 1 || !strings.Contains(recall[0].Content, "相关测试") {
+ t.Fatalf("expected auto memory to be updated, got %+v", recall)
+ }
+ manualRecall, err := svc.Recall(context.Background(), "语言偏好", ScopeUser)
+ if err != nil {
+ t.Fatalf("Recall(manual) error = %v", err)
+ }
+ if len(manualRecall) != 1 || strings.Contains(manualRecall[0].Content, "不应覆盖") {
+ t.Fatalf("manual memory should not be overwritten, got %+v", manualRecall)
+ }
+ if len(extractor.candidates) != 2 {
+ t.Fatalf("expected existing candidates to be provided, got %+v", extractor.candidates)
+ }
+}
+
+func TestAutoExtractorSemanticCreateStillUsesExactDedup(t *testing.T) {
+ svc := newAutoExtractorTestService(t)
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeUser,
+ Title: "中文回复",
+ Content: "用户偏好中文回复。",
+ Source: SourceAutoExtract,
+ }); err != nil {
+ t.Fatalf("seed Add() error = %v", err)
+ }
+
+ extractor := &stubDecisionMemoExtractor{
+ decisions: []ExtractionDecision{
+ {Action: ExtractionActionCreate, Entry: Entry{Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"}},
+ {Action: ExtractionActionCreate, Entry: Entry{Type: TypeProject, Title: "新事实", Content: "项目需要语义去重。"}},
+ },
+ }
+ auto := NewAutoExtractor(extractor, svc, time.Second)
+ auto.debounce = 5 * time.Millisecond
+ auto.logf = func(string, ...any) {}
+ registerAutoExtractorCleanup(t, auto)
+
+ auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("dedupe")}}})
+ waitFor(t, time.Second, func() bool {
+ entries, err := svc.List(context.Background(), ScopeAll)
+ return err == nil && len(entries) == 2
+ })
+}
+
func TestAutoExtractorUsesTimeoutContext(t *testing.T) {
svc := newAutoExtractorTestService(t)
extractor := &stubMemoExtractor{
diff --git a/internal/memo/llm_extractor.go b/internal/memo/llm_extractor.go
index 9d039f44..621446de 100644
--- a/internal/memo/llm_extractor.go
+++ b/internal/memo/llm_extractor.go
@@ -27,6 +27,8 @@ type LLMExtractor struct {
}
type extractedEntry struct {
+ Action string `json:"action"`
+ Ref string `json:"ref"`
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
@@ -45,8 +47,27 @@ func NewLLMExtractor(generator TextGenerator, recentMessageLimit int) *LLMExtrac
}
}
-// Extract 从最近对话中提取可跨会话持久化的记忆条目。
+// Extract 从当前 run 对话中提取可跨会话持久化的新增记忆条目。
func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) {
+ decisions, err := e.ExtractDecisions(ctx, messages, nil)
+ if err != nil {
+ return nil, err
+ }
+ entries := make([]Entry, 0, len(decisions))
+ for _, decision := range decisions {
+ if decision.Action == ExtractionActionCreate {
+ entries = append(entries, decision.Entry)
+ }
+ }
+ return entries, nil
+}
+
+// ExtractDecisions 从当前 run 对话中提取记忆,并结合既有记忆输出新增、合并或跳过决策。
+func (e *LLMExtractor) ExtractDecisions(
+ ctx context.Context,
+ messages []providertypes.Message,
+ existing []ExtractionCandidate,
+) ([]ExtractionDecision, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
@@ -54,12 +75,12 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes
return nil, errors.New("memo: text generator is nil")
}
- recent := agentcontext.BuildRecentMessagesForModel(messages, e.recentMessageLimit)
- if len(recent) == 0 || !containsUserMessage(recent) {
+ runMessages := agentcontext.BuildMemoExtractionMessagesForModel(messages)
+ if len(runMessages) == 0 || !containsUserMessage(runMessages) {
return nil, nil
}
- response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now()), recent)
+ response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now(), existing), runMessages)
if err != nil {
return nil, err
}
@@ -77,23 +98,29 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes
return nil, fmt.Errorf("memo: parse extraction response: %w", err)
}
- entries := make([]Entry, 0, len(extracted))
+ decisions := make([]ExtractionDecision, 0, len(extracted))
for _, item := range extracted {
- entry, ok := toMemoEntry(item)
+ decision, ok := toExtractionDecision(item)
if !ok {
continue
}
- entries = append(entries, entry)
+ decisions = append(decisions, decision)
}
- return entries, nil
+ return decisions, nil
}
// buildExtractionPrompt 构造记忆提取专用的 system prompt。
-func buildExtractionPrompt(now time.Time) string {
+func buildExtractionPrompt(now time.Time, existing []ExtractionCandidate) string {
currentDate := now.In(time.Local).Format("2006-01-02")
+ existingJSON := "[]"
+ if len(existing) > 0 {
+ if data, err := json.Marshal(existing); err == nil {
+ existingJSON = string(data)
+ }
+ }
return strings.TrimSpace(fmt.Sprintf(`
你是一个记忆提取助手(memory extraction assistant)。
-分析最近对话中值得跨会话持久记住的信息,并返回严格 JSON 数组。
+分析当前 run 对话中值得跨会话持久记住的信息,并结合既有记忆完成语义去重,返回严格 JSON 数组。
当前本地日期:%s
如果对话中出现“今天、明天、下周二”等相对日期,必须先转换为绝对日期再写入 content。
@@ -109,11 +136,18 @@ func buildExtractionPrompt(now time.Time) string {
2. 不要提取通用编程知识、代码结构、文件路径、Git 历史。
3. 每条记忆必须具体、可操作。
4. 没有值得记住的信息时,返回 []。
-5. 输出必须是 JSON 数组,不要输出任何额外解释。
+5. 如果新信息与既有记忆语义相同或只是轻微改写,输出 action="skip"。
+6. 如果新信息能补充或修正既有 source="extractor_auto" 的记忆,输出 action="update" 并填写目标 ref。
+7. 不允许 update source 不是 "extractor_auto" 的既有记忆;这类相近内容只能 skip。
+8. 如果是全新的可持久化信息,输出 action="create"。
+9. 输出必须是 JSON 数组,不要输出任何额外解释。
+
+既有记忆候选(JSON):
+%s
输出格式:
-[{"type":"user","title":"...","content":"...","keywords":["..."]}]
-`, currentDate))
+[{"action":"create","type":"user","title":"...","content":"...","keywords":["..."]},{"action":"update","ref":"project:p.md","title":"...","content":"...","keywords":["..."]},{"action":"skip","ref":"user:u.md"}]
+`, currentDate, existingJSON))
}
// containsUserMessage 检查待提取消息中是否包含用户输入。
@@ -148,6 +182,69 @@ func toMemoEntry(item extractedEntry) (Entry, bool) {
}, true
}
+// toExtractionDecision 将 LLM 输出收敛为自动提取持久化决策。
+func toExtractionDecision(item extractedEntry) (ExtractionDecision, bool) {
+ action := parseExtractionAction(item.Action)
+ if action == "" {
+ return ExtractionDecision{}, false
+ }
+ if action == ExtractionActionSkip {
+ return ExtractionDecision{
+ Action: action,
+ Ref: strings.TrimSpace(item.Ref),
+ }, true
+ }
+
+ if action == ExtractionActionUpdate {
+ ref := strings.TrimSpace(item.Ref)
+ if ref == "" {
+ return ExtractionDecision{}, false
+ }
+ entry, ok := toMemoUpdateEntry(item)
+ if !ok {
+ return ExtractionDecision{}, false
+ }
+ return ExtractionDecision{Action: action, Ref: ref, Entry: entry}, true
+ }
+
+ entry, ok := toMemoEntry(item)
+ if !ok {
+ return ExtractionDecision{}, false
+ }
+ return ExtractionDecision{Action: action, Entry: entry}, true
+}
+
+// toMemoUpdateEntry 将 update 决策中的可变字段收敛为 Entry 片段。
+func toMemoUpdateEntry(item extractedEntry) (Entry, bool) {
+ title := NormalizeTitle(item.Title)
+ content := strings.TrimSpace(item.Content)
+ if title == "" || content == "" {
+ return Entry{}, false
+ }
+ return Entry{
+ Title: title,
+ Content: content,
+ Keywords: normalizeKeywords(item.Keywords),
+ Source: SourceAutoExtract,
+ }, true
+}
+
+// parseExtractionAction 解析模型决策动作,并兼容旧格式中缺省 action 的 create 输出。
+func parseExtractionAction(action string) ExtractionAction {
+ switch ExtractionAction(strings.ToLower(strings.TrimSpace(action))) {
+ case "":
+ return ExtractionActionCreate
+ case ExtractionActionCreate:
+ return ExtractionActionCreate
+ case ExtractionActionUpdate:
+ return ExtractionActionUpdate
+ case ExtractionActionSkip:
+ return ExtractionActionSkip
+ default:
+ return ""
+ }
+}
+
// normalizeKeywords 规范化关键词列表,移除空值和重复值。
func normalizeKeywords(keywords []string) []string {
if len(keywords) == 0 {
diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go
index 3ccf9ee9..06813229 100644
--- a/internal/memo/llm_extractor_test.go
+++ b/internal/memo/llm_extractor_test.go
@@ -196,8 +196,8 @@ func TestLLMExtractorExtractCancelledContext(t *testing.T) {
}
}
-// TestLLMExtractorExtractUsesRecentNonToolMessages 验证只取最近 10 条非 tool 消息。
-func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) {
+// TestLLMExtractorExtractUsesFullRunMessages 验证提取器使用完整 run 消息而非固定 recent window。
+func TestLLMExtractorExtractUsesFullRunMessages(t *testing.T) {
generator := &stubTextGenerator{response: `[]`}
extractor := NewLLMExtractor(generator, 10)
@@ -219,19 +219,64 @@ func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) {
if err != nil {
t.Fatalf("Extract() error = %v", err)
}
- if len(generator.messages) != 10 {
- t.Fatalf("len(generator.messages) = %d, want 10", len(generator.messages))
+ if len(generator.messages) != 12 {
+ t.Fatalf("len(generator.messages) = %d, want 12", len(generator.messages))
}
for _, message := range generator.messages {
if message.Role == providertypes.RoleTool {
t.Fatalf("unexpected tool message in extraction context: %#v", message)
}
}
- if renderMemoParts(generator.messages[0].Parts) != "user-c" ||
- renderMemoParts(generator.messages[9].Parts) != "user-l" {
- t.Fatalf("unexpected recent window: first=%q last=%q",
+ if renderMemoParts(generator.messages[0].Parts) != "user-a" ||
+ renderMemoParts(generator.messages[11].Parts) != "user-l" {
+ t.Fatalf("unexpected run window: first=%q last=%q",
renderMemoParts(generator.messages[0].Parts),
- renderMemoParts(generator.messages[9].Parts))
+ renderMemoParts(generator.messages[11].Parts))
+ }
+}
+
+func TestLLMExtractorExtractDecisionsIncludesExistingCandidates(t *testing.T) {
+ generator := &stubTextGenerator{
+ response: `[{"action":"skip","ref":"user:u.md"},{"action":"update","ref":"project:p.md","title":"测试策略","content":"用户要求修改后先跑相关测试。","keywords":["test"]}]`,
+ }
+ extractor := NewLLMExtractor(generator, 10)
+
+ decisions, err := extractor.ExtractDecisions(
+ context.Background(),
+ []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑相关测试。")}},
+ },
+ []ExtractionCandidate{
+ {
+ Ref: "project:p.md",
+ Scope: ScopeProject,
+ Type: TypeFeedback,
+ Source: SourceAutoExtract,
+ Title: "测试策略",
+ Content: "用户要求修改后先跑测试。",
+ },
+ },
+ )
+ if err != nil {
+ t.Fatalf("ExtractDecisions() error = %v", err)
+ }
+ if len(decisions) != 2 {
+ t.Fatalf("len(decisions) = %d, want 2", len(decisions))
+ }
+ if decisions[0].Action != ExtractionActionSkip || decisions[0].Ref != "user:u.md" {
+ t.Fatalf("unexpected skip decision: %+v", decisions[0])
+ }
+ if decisions[1].Action != ExtractionActionUpdate || decisions[1].Ref != "project:p.md" {
+ t.Fatalf("unexpected update decision: %+v", decisions[1])
+ }
+ if decisions[1].Entry.Type != "" {
+ t.Fatalf("update decision should not require type, got %+v", decisions[1].Entry)
+ }
+ if !strings.Contains(generator.prompt, `"ref":"project:p.md"`) {
+ t.Fatalf("prompt should include existing memory candidates, got %q", generator.prompt)
+ }
+ if !strings.Contains(generator.prompt, `action="update"`) || !strings.Contains(generator.prompt, `source="extractor_auto"`) {
+ t.Fatalf("prompt should describe semantic dedupe protocol, got %q", generator.prompt)
}
}
diff --git a/internal/memo/service.go b/internal/memo/service.go
index c7415c50..42551f26 100644
--- a/internal/memo/service.go
+++ b/internal/memo/service.go
@@ -68,6 +68,130 @@ func (s *Service) addAutoExtractIfAbsent(ctx context.Context, entry Entry) (bool
return true, nil
}
+// autoExtractionCandidates 加载既有记忆快照,供模型在提取时做语义去重判断。
+func (s *Service) autoExtractionCandidates(ctx context.Context) ([]ExtractionCandidate, error) {
+ if s == nil {
+ return nil, nil
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ candidates := make([]ExtractionCandidate, 0)
+ for _, scope := range supportedStorageScopes() {
+ index, err := s.loadIndexLocked(ctx, scope)
+ if err != nil {
+ return nil, err
+ }
+ for _, entry := range index.Entries {
+ topicFile := strings.TrimSpace(entry.TopicFile)
+ if topicFile == "" {
+ continue
+ }
+ topicContent, err := s.store.LoadTopic(ctx, scope, topicFile)
+ if err != nil {
+ continue
+ }
+ source, content := parseTopicSourceAndContent(topicContent)
+ candidates = append(candidates, ExtractionCandidate{
+ Ref: scopedTopicKey(scope, topicFile),
+ Scope: scope,
+ Type: entry.Type,
+ Source: source,
+ Title: entry.Title,
+ Content: content,
+ })
+ }
+ }
+ return candidates, nil
+}
+
+// updateAutoExtractIfAllowed 按 ref 更新既有自动提取记忆,显式记忆不会被后台流程覆盖。
+func (s *Service) updateAutoExtractIfAllowed(ctx context.Context, ref string, next Entry) (bool, error) {
+ next.Title = NormalizeTitle(next.Title)
+ next.Content = strings.TrimSpace(next.Content)
+ next.Keywords = normalizeKeywords(next.Keywords)
+ if next.Title == "" {
+ return false, fmt.Errorf("memo: title is empty")
+ }
+ if next.Content == "" {
+ return false, fmt.Errorf("memo: content is empty")
+ }
+ if err := s.ensureAutoExtractIndex(ctx); err != nil {
+ return false, err
+ }
+
+ scope, topicFile, ok := parseScopedTopicKey(ref)
+ if !ok {
+ return false, nil
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ index, err := s.loadIndexLocked(ctx, scope)
+ if err != nil {
+ return false, err
+ }
+ working := cloneIndex(index)
+
+ targetIndex := -1
+ var current Entry
+ for idx, existing := range working.Entries {
+ if strings.TrimSpace(existing.TopicFile) == topicFile {
+ targetIndex = idx
+ current = existing
+ break
+ }
+ }
+ if targetIndex < 0 {
+ return false, nil
+ }
+
+ topicContent, err := s.store.LoadTopic(ctx, scope, topicFile)
+ if err != nil {
+ return false, nil
+ }
+ source, _ := parseTopicSourceAndContent(topicContent)
+ if source != SourceAutoExtract {
+ return false, nil
+ }
+
+ current.Title = next.Title
+ current.Content = next.Content
+ current.Keywords = next.Keywords
+ current.Source = SourceAutoExtract
+ current.TopicFile = topicFile
+ working.Entries[targetIndex] = current
+ working.UpdatedAt = time.Now()
+
+ removedEntries := trimIndexEntries(working, s.config.MaxEntries, s.config.MaxIndexBytes)
+ if err := s.store.SaveTopic(ctx, scope, topicFile, RenderTopic(¤t)); err != nil {
+ return false, fmt.Errorf("memo: save topic: %w", err)
+ }
+ if err := s.store.SaveIndex(ctx, scope, working); err != nil {
+ return false, fmt.Errorf("memo: save index: %w", err)
+ }
+ for _, removed := range removedEntries {
+ if removedTopic := strings.TrimSpace(removed.TopicFile); removedTopic != "" {
+ _ = s.store.DeleteTopic(ctx, scope, removedTopic)
+ }
+ }
+
+ if s.autoExtractIndexReady {
+ s.removeAutoExtractTopicLocked(scope, topicFile)
+ for _, removed := range removedEntries {
+ s.removeAutoExtractTopicLocked(scope, removed.TopicFile)
+ }
+ if indexContainsTopicFile(working, topicFile) {
+ s.trackAutoExtractEntryLocked(scope, current)
+ }
+ }
+
+ s.invalidateCache()
+ return true, nil
+}
+
// normalizeKeyword 统一关键词的空格与大小写处理。
func normalizeKeyword(keyword string) string {
return strings.ToLower(strings.TrimSpace(keyword))
@@ -557,6 +681,23 @@ func indexContainsEntryID(index *Index, entryID string) bool {
return false
}
+// indexContainsTopicFile 判断索引中是否仍保留指定 topic 文件。
+func indexContainsTopicFile(index *Index, topicFile string) bool {
+ if index == nil {
+ return false
+ }
+ topicFile = strings.TrimSpace(topicFile)
+ if topicFile == "" {
+ return false
+ }
+ for _, item := range index.Entries {
+ if strings.TrimSpace(item.TopicFile) == topicFile {
+ return true
+ }
+ }
+ return false
+}
+
// scopesForQuery 将查询范围展开为实际存储分层列表。
func scopesForQuery(scope Scope) []Scope {
switch NormalizeScope(scope) {
@@ -584,3 +725,20 @@ func validateQueryScope(scope Scope) error {
func scopedTopicKey(scope Scope, topicFile string) string {
return string(scope) + ":" + strings.TrimSpace(topicFile)
}
+
+// parseScopedTopicKey 解析自动提取语义去重协议中的 ref 字段。
+func parseScopedTopicKey(ref string) (Scope, string, bool) {
+ parts := strings.SplitN(strings.TrimSpace(ref), ":", 2)
+ if len(parts) != 2 {
+ return "", "", false
+ }
+ scope := Scope(strings.TrimSpace(parts[0]))
+ if err := validateStorageScope(scope); err != nil {
+ return "", "", false
+ }
+ topicFile := strings.TrimSpace(parts[1])
+ if topicFile == "" {
+ return "", "", false
+ }
+ return scope, topicFile, true
+}
diff --git a/internal/memo/types.go b/internal/memo/types.go
index 8572e0e4..0d1a62a0 100644
--- a/internal/memo/types.go
+++ b/internal/memo/types.go
@@ -89,6 +89,35 @@ type RecalledEntry struct {
Content string
}
+// ExtractionAction 表示自动提取器对单条候选记忆的持久化决策。
+type ExtractionAction string
+
+const (
+ // ExtractionActionCreate 表示新增一条记忆。
+ ExtractionActionCreate ExtractionAction = "create"
+ // ExtractionActionUpdate 表示合并更新一条既有自动提取记忆。
+ ExtractionActionUpdate ExtractionAction = "update"
+ // ExtractionActionSkip 表示跳过重复或不值得沉淀的内容。
+ ExtractionActionSkip ExtractionAction = "skip"
+)
+
+// ExtractionCandidate 表示提供给模型做语义去重的既有记忆快照。
+type ExtractionCandidate struct {
+ Ref string `json:"ref"`
+ Scope Scope `json:"scope"`
+ Type Type `json:"type"`
+ Source string `json:"source"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+}
+
+// ExtractionDecision 表示模型针对新旧记忆关系返回的结构化决策。
+type ExtractionDecision struct {
+ Action ExtractionAction
+ Ref string
+ Entry Entry
+}
+
// Store 定义记忆持久化的最小抽象。
type Store interface {
LoadIndex(ctx context.Context, scope Scope) (*Index, error)
@@ -104,6 +133,15 @@ type Extractor interface {
Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error)
}
+// DecisionExtractor 定义带既有记忆快照的语义提取能力。
+type DecisionExtractor interface {
+ ExtractDecisions(
+ ctx context.Context,
+ messages []providertypes.Message,
+ existing []ExtractionCandidate,
+ ) ([]ExtractionDecision, error)
+}
+
// TextGenerator 定义调用 LLM 生成文本的最小能力,用于记忆提取。
// 该接口隔离 provider 细节,避免 memo 包直接依赖 provider 基础设施。
type TextGenerator interface {
diff --git a/internal/runtime/memo.go b/internal/runtime/memo.go
index 477933d7..d2ce6202 100644
--- a/internal/runtime/memo.go
+++ b/internal/runtime/memo.go
@@ -19,6 +19,30 @@ func (s *Service) triggerMemoExtraction(sessionID string, messages []providertyp
s.memoExtractor.Schedule(sessionID, cloneMessages(messages))
}
+// runBoundaryMessagesForMemo 返回当前 run 边界内的消息切片,供自动记忆提取使用。
+func runBoundaryMessagesForMemo(state *runState) []providertypes.Message {
+ if state == nil {
+ return nil
+ }
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ return cloneMessages(state.memoRunMessages)
+}
+
+// appendMemoRunMessage 记录当前 run 内已成功写入 transcript 的消息,作为自动记忆提取边界。
+func appendMemoRunMessage(state *runState, message providertypes.Message) {
+ if state == nil || message.IsEmpty() {
+ return
+ }
+ cloned := cloneMessages([]providertypes.Message{message})
+ if len(cloned) == 0 {
+ return
+ }
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ state.memoRunMessages = append(state.memoRunMessages, cloned[0])
+}
+
// isSuccessfulRememberToolCall 判断工具调用是否成功完成显式记忆写入。
func isSuccessfulRememberToolCall(callName string, result tools.ToolResult, execErr error) bool {
if execErr != nil || result.IsError {
diff --git a/internal/runtime/run.go b/internal/runtime/run.go
index 6887ee55..4fb3d22d 100644
--- a/internal/runtime/run.go
+++ b/internal/runtime/run.go
@@ -472,7 +472,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) {
})
recordAcceptanceTerminal(&state, acceptanceDecision)
s.emitRunScoped(ctx, EventAgentDone, &state, turnOutput.assistant)
- s.triggerMemoExtraction(state.session.ID, state.session.Messages, state.rememberedThisRun)
+ s.triggerMemoExtraction(state.session.ID, runBoundaryMessagesForMemo(&state), state.rememberedThisRun)
return nil
case acceptance.AcceptanceContinue:
state.lastAcceptanceBlockedReason = strings.TrimSpace(acceptanceDecision.CompletionBlockedReason)
diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go
index f1db6601..a8f72459 100644
--- a/internal/runtime/runtime.go
+++ b/internal/runtime/runtime.go
@@ -145,7 +145,7 @@ type ProviderFactory interface {
// MemoExtractor 定义 runtime 层调用记忆提取的最小能力。
// 通过接口注入避免 runtime 直接依赖 memo 子系统实现细节。
type MemoExtractor interface {
- // Schedule 从消息中安排一次后台记忆提取,失败由实现方自行处理。
+ // Schedule 从当前 run 边界内的消息安排一次后台记忆提取,失败由实现方自行处理。
Schedule(sessionID string, messages []providertypes.Message)
}
diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go
index 9a38cc91..5e1ca56d 100644
--- a/internal/runtime/runtime_test.go
+++ b/internal/runtime/runtime_test.go
@@ -1180,6 +1180,52 @@ func TestServiceRunSchedulesMemoExtractionAfterFinalReply(t *testing.T) {
}
}
+func TestServiceRunSchedulesMemoExtractionFromCurrentRunBoundary(t *testing.T) {
+ t.Parallel()
+
+ manager := newRuntimeConfigManager(t)
+ store := newMemoryStore()
+ session := agentsession.New("existing")
+ session.ID = "session-existing-memo"
+ session.Messages = []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("old user")}},
+ {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("old assistant")}},
+ }
+ store.sessions[session.ID] = cloneSession(session)
+
+ scripted := &scriptedProvider{
+ streams: [][]providertypes.StreamEvent{
+ {
+ providertypes.NewTextDeltaStreamEvent("new final"),
+ providertypes.NewMessageDoneStreamEvent("stop", nil),
+ },
+ },
+ }
+ service := NewWithFactory(manager, tools.NewRegistry(), store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{})
+ memoExtractor := &stubScheduledMemoExtractor{}
+ service.SetMemoExtractor(memoExtractor)
+
+ err := service.Run(context.Background(), UserInput{
+ SessionID: session.ID,
+ RunID: "run-memo-boundary",
+ Parts: []providertypes.ContentPart{providertypes.NewTextPart("new user")},
+ })
+ if err != nil {
+ t.Fatalf("Run() error = %v", err)
+ }
+ if len(memoExtractor.calls) != 1 {
+ t.Fatalf("memo schedule calls = %d, want 1", len(memoExtractor.calls))
+ }
+ messages := memoExtractor.calls[0].messages
+ if len(messages) != 2 {
+ t.Fatalf("scheduled messages = %#v, want current run user+assistant only", messages)
+ }
+ if renderPartsForVerification(messages[0].Parts) != "new user" ||
+ renderPartsForVerification(messages[1].Parts) != "new final" {
+ t.Fatalf("scheduled messages crossed run boundary: %#v", messages)
+ }
+}
+
func TestServiceRunSkipsAutoMemoExtractionAfterRememberTool(t *testing.T) {
t.Parallel()
@@ -5555,6 +5601,8 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) {
}
service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, builder)
+ memoExtractor := &stubScheduledMemoExtractor{}
+ service.SetMemoExtractor(memoExtractor)
service.compactRunner = &stubCompactRunner{
result: contextcompact.Result{
Messages: []providertypes.Message{
@@ -5617,6 +5665,16 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) {
if renderPartsForTest(saved.Messages[2].Parts) != "recovered" {
t.Fatalf("expected final assistant reply %q, got %q", "recovered", renderPartsForTest(saved.Messages[2].Parts))
}
+ if len(memoExtractor.calls) != 1 {
+ t.Fatalf("memo schedule calls = %d, want 1", len(memoExtractor.calls))
+ }
+ memoMessages := memoExtractor.calls[0].messages
+ if len(memoMessages) != 2 {
+ t.Fatalf("memo messages = %+v, want current run user+assistant only after compact", memoMessages)
+ }
+ if renderPartsForTest(memoMessages[0].Parts) != "continue" || renderPartsForTest(memoMessages[1].Parts) != "recovered" {
+ t.Fatalf("memo messages crossed compact boundary: %+v", memoMessages)
+ }
events := collectRuntimeEvents(service.Events())
assertEventSequence(t, events, []EventType{
diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go
index 26d3d9bb..9f00db6f 100644
--- a/internal/runtime/session_mutation.go
+++ b/internal/runtime/session_mutation.go
@@ -29,6 +29,7 @@ func (s *Service) appendUserMessageAndSave(ctx context.Context, state *runState,
}); err != nil {
return err
}
+ appendMemoRunMessage(state, message)
s.emitRunScoped(ctx, EventUserMessage, state, message)
return nil
}
@@ -87,7 +88,7 @@ func (s *Service) appendAssistantMessageOnlyAndSave(
}
state.session.Messages = append(state.session.Messages, assistant)
state.touchSession()
- return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{
+ if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{
SessionID: state.session.ID,
Messages: []providertypes.Message{assistant},
UpdatedAt: state.session.UpdatedAt,
@@ -95,7 +96,11 @@ func (s *Service) appendAssistantMessageOnlyAndSave(
Model: state.session.Model,
Workdir: state.session.Workdir,
HasUnknownUsage: state.session.HasUnknownUsage,
- })
+ }); err != nil {
+ return err
+ }
+ appendMemoRunMessage(state, assistant)
+ return nil
}
// appendSystemMessageAndSave 将系统提醒消息追加到会话并持久化。
@@ -111,7 +116,7 @@ func (s *Service) appendSystemMessageAndSave(ctx context.Context, state *runStat
}
state.session.Messages = append(state.session.Messages, message)
state.touchSession()
- return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{
+ if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{
SessionID: state.session.ID,
Messages: []providertypes.Message{message},
UpdatedAt: state.session.UpdatedAt,
@@ -119,7 +124,11 @@ func (s *Service) appendSystemMessageAndSave(ctx context.Context, state *runStat
Model: state.session.Model,
Workdir: state.session.Workdir,
HasUnknownUsage: state.session.HasUnknownUsage,
- })
+ }); err != nil {
+ return err
+ }
+ appendMemoRunMessage(state, message)
+ return nil
}
// appendToolMessageAndSave 将工具原始结果写回会话,持久化时仅追加一条 tool message。
@@ -143,7 +152,11 @@ func (s *Service) appendToolMessageAndSave(
HasUnknownUsage: state.session.HasUnknownUsage,
}
state.mu.Unlock()
- return s.sessionStore.AppendMessages(ctx, input)
+ if err := s.sessionStore.AppendMessages(ctx, input); err != nil {
+ return err
+ }
+ appendMemoRunMessage(state, toolMessage)
+ return nil
}
// normalizeToolMessageForPersistence 负责在写入会话前收敛工具结果,避免成功结果落成完全空语义消息。
diff --git a/internal/runtime/state.go b/internal/runtime/state.go
index 2440bddf..277d9c6e 100644
--- a/internal/runtime/state.go
+++ b/internal/runtime/state.go
@@ -4,6 +4,7 @@ import (
"sync"
"time"
+ providertypes "neo-code/internal/provider/types"
"neo-code/internal/runtime/controlplane"
"neo-code/internal/runtime/decider"
runtimefacts "neo-code/internal/runtime/facts"
@@ -20,6 +21,7 @@ type runState struct {
effectiveWorkdir string
compactCount int
reactiveCompactAttempts int
+ memoRunMessages []providertypes.Message
rememberedThisRun bool
planningEnabled bool
taskID string
From 1b19d3047b7b9131229eb73bb768e4c30468def9 Mon Sep 17 00:00:00 2001
From: Yumiue <229866007@qq.com>
Date: Sun, 10 May 2026 10:47:56 +0800
Subject: [PATCH 2/5] =?UTF-8?q?fix:web=E7=AB=AFslash=E6=8C=87=E4=BB=A4?=
=?UTF-8?q?=E4=BF=AE=E5=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
web/src/components/chat/ChatInput.test.tsx | 69 ++++++++++++++++++++++
web/src/components/chat/ChatInput.tsx | 25 ++++----
2 files changed, 79 insertions(+), 15 deletions(-)
diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx
index 8c610e7b..bf39b1b2 100644
--- a/web/src/components/chat/ChatInput.test.tsx
+++ b/web/src/components/chat/ChatInput.test.tsx
@@ -25,6 +25,12 @@ vi.mock('./ModelSelector', () => ({
default: () =>
,
}))
+async function submitSlashCommand(command: string) {
+ const textarea = screen.getByRole('textbox') as HTMLTextAreaElement
+ fireEvent.change(textarea, { target: { value: command } })
+ fireEvent.keyDown(textarea, { key: 'Enter' })
+}
+
describe('ChatInput', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -134,4 +140,67 @@ describe('ChatInput', () => {
expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument()
expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument()
})
+ it('executes /memo without session id and shows payload.Content', async () => {
+ mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({
+ payload: {
+ Content: 'User Memo:\n- [user] coding preference',
+ },
+ })
+ render()
+
+ await submitSlashCommand('/memo')
+
+ await waitFor(() => {
+ expect(mockGatewayAPI.executeSystemTool).toHaveBeenCalledWith('', '', 'memo_list', {})
+ })
+ await waitFor(() => {
+ expect(useChatStore.getState().messages.some((msg) => msg.type === 'system' && msg.content.includes('coding preference'))).toBe(true)
+ })
+ })
+
+ it('uses fallback text when memo payload has no content field', async () => {
+ mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({ payload: {} })
+ render()
+
+ await submitSlashCommand('/memo')
+
+ await waitFor(() => {
+ expect(useChatStore.getState().messages.some((msg) => msg.type === 'system' && msg.content === 'Memo query complete')).toBe(true)
+ })
+ })
+
+ it('executes /remember and /forget without session id', async () => {
+ mockGatewayAPI.executeSystemTool
+ .mockResolvedValueOnce({ payload: { Content: 'Memory saved: [user] keep tests strict' } })
+ .mockResolvedValueOnce({ payload: { Content: 'Removed 1 memo(s) matching \"strict\".' } })
+ render()
+
+ await submitSlashCommand('/remember keep tests strict')
+ await waitFor(() => {
+ expect(mockGatewayAPI.executeSystemTool).toHaveBeenNthCalledWith(1, '', '', 'memo_remember', {
+ type: 'user',
+ title: 'keep tests strict',
+ content: 'keep tests strict',
+ })
+ })
+
+ await submitSlashCommand('/forget strict')
+ await waitFor(() => {
+ expect(mockGatewayAPI.executeSystemTool).toHaveBeenNthCalledWith(2, '', '', 'memo_remove', {
+ keyword: 'strict',
+ scope: 'all',
+ })
+ })
+ })
+
+ it('keeps argument validation for /remember and /forget', async () => {
+ render()
+
+ await submitSlashCommand('/remember')
+ await submitSlashCommand('/forget')
+
+ await waitFor(() => {
+ expect(mockGatewayAPI.executeSystemTool).not.toHaveBeenCalled()
+ })
+ })
})
diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx
index ff4258e4..d5a4bb2e 100644
--- a/web/src/components/chat/ChatInput.tsx
+++ b/web/src/components/chat/ChatInput.tsx
@@ -57,6 +57,13 @@ function buildSlashHelpText(commands: AnySlashCommand[]): string {
return ['可用命令:', ...lines].join('\n')
}
+/** 统一提取系统工具返回文本,兼容 payload.content 与 payload.Content。 */
+function extractSystemToolContent(result: unknown, fallback: string): string {
+ const payload = (result as { payload?: { content?: string; Content?: string } } | null)?.payload
+ const content = payload?.content ?? payload?.Content
+ return content || fallback
+}
+
export default function ChatInput() {
const gatewayAPI = useGatewayAPI()
const text = useComposerStore((state) => state.composerText)
@@ -149,13 +156,9 @@ export default function ChatInput() {
return true
}
case '/memo': {
- if (!isValidSessionId(currentSessionId)) {
- useUIStore.getState().showToast('Send a message first to start a session', 'error')
- return true
- }
try {
const result = await api.executeSystemTool(currentSessionId, '', 'memo_list', {})
- addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo query complete')
+ addSystemMessage(extractSystemToolContent(result, 'Memo query complete'))
} catch (err) {
console.error('Memo list failed:', err)
useUIStore.getState().showToast('Failed to query memo', 'error')
@@ -167,17 +170,13 @@ export default function ChatInput() {
useUIStore.getState().showToast('Usage: /remember ', 'error')
return true
}
- if (!isValidSessionId(currentSessionId)) {
- useUIStore.getState().showToast('Send a message first to start a session', 'error')
- return true
- }
try {
const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', {
type: 'user',
title: argument,
content: argument,
})
- addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo saved')
+ addSystemMessage(extractSystemToolContent(result, 'Memo saved'))
} catch (err) {
console.error('Remember failed:', err)
useUIStore.getState().showToast('Failed to save memo', 'error')
@@ -189,16 +188,12 @@ export default function ChatInput() {
useUIStore.getState().showToast('Usage: /forget ', 'error')
return true
}
- if (!isValidSessionId(currentSessionId)) {
- useUIStore.getState().showToast('Send a message first to start a session', 'error')
- return true
- }
try {
const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', {
keyword: argument,
scope: 'all',
})
- addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo deleted')
+ addSystemMessage(extractSystemToolContent(result, 'Memo deleted'))
} catch (err) {
console.error('Forget failed:', err)
useUIStore.getState().showToast('Failed to delete memo', 'error')
From 5cde1d19f005baa0d52abe43ae30660260b466ac Mon Sep 17 00:00:00 2001
From: Yumiue <229866007@qq.com>
Date: Sun, 10 May 2026 10:36:56 +0800
Subject: [PATCH 3/5] =?UTF-8?q?fix(memo):memo=E5=88=9D=E6=AD=A5=E4=BF=AE?=
=?UTF-8?q?=E7=90=86,run=E6=88=AA=E5=8F=96=EF=BC=8Cllm=E6=8F=90=E5=8F=96?=
=?UTF-8?q?=E6=97=B6=E5=8E=BB=E9=87=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
internal/app/bootstrap_test.go | 42 +++++++
internal/context/projection.go | 36 ++++++
internal/context/projection_test.go | 44 ++++++++
internal/memo/auto_extractor.go | 64 +++++++++++
internal/memo/auto_extractor_test.go | 144 ++++++++++++++++++++++++
internal/memo/llm_extractor.go | 123 ++++++++++++++++++---
internal/memo/llm_extractor_test.go | 61 +++++++++--
internal/memo/service.go | 158 +++++++++++++++++++++++++++
internal/memo/types.go | 38 +++++++
internal/runtime/memo.go | 24 ++++
internal/runtime/run.go | 2 +-
internal/runtime/runtime.go | 2 +-
internal/runtime/runtime_test.go | 58 ++++++++++
internal/runtime/session_mutation.go | 23 +++-
internal/runtime/state.go | 2 +
15 files changed, 793 insertions(+), 28 deletions(-)
diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go
index 3760dded..917823cb 100644
--- a/internal/app/bootstrap_test.go
+++ b/internal/app/bootstrap_test.go
@@ -1675,6 +1675,48 @@ func TestNewMemoExtractorAdapterBuildsProviderSafeMemoWindow(t *testing.T) {
}
}
+func TestNewMemoExtractorAdapterUsesFullRunMemoWindow(t *testing.T) {
+ t.Setenv(config.OpenAIDefaultAPIKeyEnv, "token")
+ cfg := config.StaticDefaults().Clone()
+ cfg.SelectedProvider = config.OpenAIName
+ cfg.Memo.ExtractRecentMessages = 3
+ manager := config.NewManager(config.NewLoader("", &cfg))
+
+ providerStub := &stubMemoProvider{
+ generate: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error {
+ if len(req.Messages) != 12 {
+ t.Fatalf("unexpected memo window length %d, want full run: %+v", len(req.Messages), req.Messages)
+ }
+ events <- providertypes.NewTextDeltaStreamEvent(`[]`)
+ events <- providertypes.NewMessageDoneStreamEvent("stop", nil)
+ return nil
+ },
+ }
+ factory := &stubMemoProviderFactory{provider: providerStub}
+ scheduler := &stubMemoExtractorScheduler{}
+ extractor := newMemoExtractorAdapter(factory, manager, scheduler)
+
+ inputMessages := make([]providertypes.Message, 0, 12)
+ for index := 0; index < 12; index++ {
+ inputMessages = append(inputMessages, providertypes.Message{
+ Role: providertypes.RoleUser,
+ Parts: []providertypes.ContentPart{providertypes.NewTextPart(fmt.Sprintf("message-%02d", index))},
+ })
+ }
+ extractor.Schedule("session-1", inputMessages)
+ if !scheduler.called || scheduler.extractor == nil {
+ t.Fatalf("expected scheduler to receive extractor")
+ }
+
+ _, err := scheduler.extractor.Extract(context.Background(), inputMessages)
+ if err != nil {
+ t.Fatalf("extractor.Extract() error = %v", err)
+ }
+ if !factory.called {
+ t.Fatalf("expected provider factory Build to be called")
+ }
+}
+
func TestNewMemoExtractorAdapterKeepsScheduledConfigSnapshot(t *testing.T) {
t.Setenv(config.OpenAIDefaultAPIKeyEnv, "openai-token")
t.Setenv(config.QiniuDefaultAPIKeyEnv, "qiniu-token")
diff --git a/internal/context/projection.go b/internal/context/projection.go
index ac7da4b3..298f7e8c 100644
--- a/internal/context/projection.go
+++ b/internal/context/projection.go
@@ -79,6 +79,42 @@ func BuildRecentMessagesForModel(messages []providertypes.Message, limit int) []
return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected)))
}
+// BuildMemoExtractionMessagesForModel 构造完整 run 的 provider-safe 记忆提取上下文。
+func BuildMemoExtractionMessagesForModel(messages []providertypes.Message) []providertypes.Message {
+ if len(messages) == 0 {
+ return nil
+ }
+
+ keep := make([]bool, len(messages))
+ for index := 0; index < len(messages); index++ {
+ message := messages[index]
+ if message.Role == providertypes.RoleTool {
+ continue
+ }
+
+ if message.Role == providertypes.RoleAssistant && len(message.ToolCalls) > 0 {
+ for _, spanIndex := range matchedToolCallSpan(messages, index) {
+ keep[spanIndex] = true
+ }
+ continue
+ }
+
+ keep[index] = true
+ }
+
+ selected := make([]providertypes.Message, 0, len(messages))
+ for index, message := range messages {
+ if keep[index] {
+ selected = append(selected, message)
+ }
+ }
+ if len(selected) == 0 {
+ return nil
+ }
+
+ return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected)))
+}
+
// matchedToolCallSpan 返回 assistant tool call 与其完整 tool 响应组成的合法窗口下标集合。
func matchedToolCallSpan(messages []providertypes.Message, assistantIndex int) []int {
if assistantIndex < 0 || assistantIndex >= len(messages) {
diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go
index 372e3c3f..3b0c932a 100644
--- a/internal/context/projection_test.go
+++ b/internal/context/projection_test.go
@@ -140,6 +140,50 @@ func TestBuildRecentMessagesForModelKeepsOnlyRecentValidAnchors(t *testing.T) {
}
}
+func TestBuildMemoExtractionMessagesForModelKeepsFullRunSafeSpans(t *testing.T) {
+ t.Parallel()
+
+ messages := []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}},
+ {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan")}},
+ {
+ Role: providertypes.RoleAssistant,
+ ToolCalls: []providertypes.ToolCall{
+ {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`},
+ },
+ },
+ {
+ Role: providertypes.RoleTool,
+ ToolCallID: "call-1",
+ Parts: []providertypes.ContentPart{providertypes.NewTextPart("README body")},
+ ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"},
+ },
+ {
+ Role: providertypes.RoleAssistant,
+ ToolCalls: []providertypes.ToolCall{
+ {ID: "call-missing", Name: "bash", Arguments: `{}`},
+ },
+ },
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("last")}},
+ }
+
+ projected := BuildMemoExtractionMessagesForModel(messages)
+ if len(projected) != 4 {
+ t.Fatalf("len(projected) = %d, want 4: %+v", len(projected), projected)
+ }
+ if renderDisplayParts(projected[0].Parts) != "first" || renderDisplayParts(projected[3].Parts) != "last" {
+ t.Fatalf("expected full run user messages to remain, got %+v", projected)
+ }
+ if projected[1].Role != providertypes.RoleAssistant || len(projected[1].ToolCalls) != 1 {
+ t.Fatalf("expected complete assistant tool span, got %+v", projected[1])
+ }
+ if projected[2].Role != providertypes.RoleTool ||
+ !strings.Contains(renderDisplayParts(projected[2].Parts), "tool result") ||
+ projected[2].ToolMetadata != nil {
+ t.Fatalf("expected projected tool result, got %+v", projected[2])
+ }
+}
+
func TestBuildRecentMessagesForModelRespectsAbsoluteMessageBudget(t *testing.T) {
t.Parallel()
diff --git a/internal/memo/auto_extractor.go b/internal/memo/auto_extractor.go
index aaadfec9..a1ccb43b 100644
--- a/internal/memo/auto_extractor.go
+++ b/internal/memo/auto_extractor.go
@@ -225,6 +225,24 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider
ctx, cancel := context.WithTimeout(context.Background(), a.extractTimeout)
defer cancel()
+ if decisionExtractor, ok := extractor.(DecisionExtractor); ok {
+ existing, err := a.svc.autoExtractionCandidates(ctx)
+ if err != nil {
+ a.logError("memo: auto extract load candidates failed: %v", err)
+ return false
+ }
+ decisions, err := decisionExtractor.ExtractDecisions(ctx, messages, existing)
+ if err != nil {
+ if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) {
+ a.logError("memo: auto extract skipped (protocol_mismatch): %v", err)
+ return true
+ }
+ a.logError("memo: auto extract failed: %v", err)
+ return false
+ }
+ return a.applyExtractionDecisions(ctx, decisions)
+ }
+
entries, err := extractor.Extract(ctx, messages)
if err != nil {
if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) {
@@ -265,6 +283,52 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider
return succeeded
}
+// applyExtractionDecisions 应用模型返回的 create/update/skip 决策,并保留本地精确去重兜底。
+func (a *AutoExtractor) applyExtractionDecisions(ctx context.Context, decisions []ExtractionDecision) bool {
+ if len(decisions) == 0 {
+ return true
+ }
+
+ seenCreates := make(map[string]struct{}, len(decisions))
+ succeeded := true
+ for _, decision := range decisions {
+ switch decision.Action {
+ case ExtractionActionCreate:
+ entry := decision.Entry
+ entry.Source = SourceAutoExtract
+ key := autoExtractDedupKey(entry)
+ if key == "" {
+ continue
+ }
+ if _, exists := seenCreates[key]; exists {
+ continue
+ }
+ added, err := a.svc.addAutoExtractIfAbsent(ctx, entry)
+ if err != nil {
+ a.logError("memo: auto extract add failed: %v", err)
+ succeeded = false
+ continue
+ }
+ seenCreates[key] = struct{}{}
+ if !added {
+ continue
+ }
+ case ExtractionActionUpdate:
+ entry := decision.Entry
+ entry.Source = SourceAutoExtract
+ _, err := a.svc.updateAutoExtractIfAllowed(ctx, decision.Ref, entry)
+ if err != nil {
+ a.logError("memo: auto extract update failed: %v", err)
+ succeeded = false
+ continue
+ }
+ case ExtractionActionSkip:
+ continue
+ }
+ }
+ return succeeded
+}
+
// autoExtractDedupKey 生成自动提取条目的精确去重键。
func autoExtractDedupKey(entry Entry) string {
title := NormalizeTitle(entry.Title)
diff --git a/internal/memo/auto_extractor_test.go b/internal/memo/auto_extractor_test.go
index 2f188d07..6c4e5835 100644
--- a/internal/memo/auto_extractor_test.go
+++ b/internal/memo/auto_extractor_test.go
@@ -38,6 +38,33 @@ func (s *stubMemoExtractor) Calls() int {
return s.callCount
}
+type stubDecisionMemoExtractor struct {
+ mu sync.Mutex
+ callCount int
+ candidates []ExtractionCandidate
+ decisions []ExtractionDecision
+ err error
+}
+
+func (s *stubDecisionMemoExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) {
+ return nil, nil
+}
+
+func (s *stubDecisionMemoExtractor) ExtractDecisions(
+ ctx context.Context,
+ messages []providertypes.Message,
+ existing []ExtractionCandidate,
+) ([]ExtractionDecision, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.callCount++
+ s.candidates = append([]ExtractionCandidate(nil), existing...)
+ if s.err != nil {
+ return nil, s.err
+ }
+ return append([]ExtractionDecision(nil), s.decisions...), nil
+}
+
func newAutoExtractorTestService(t *testing.T) *Service {
t.Helper()
store := NewFileStore(t.TempDir(), t.TempDir())
@@ -214,6 +241,123 @@ func TestAutoExtractorSuppressesExactDuplicates(t *testing.T) {
})
}
+func TestAutoExtractorAppliesSemanticUpdateOnlyForAutoExtractedMemory(t *testing.T) {
+ svc := newAutoExtractorTestService(t)
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeFeedback,
+ Title: "测试策略",
+ Content: "用户要求修改后先跑测试。",
+ Source: SourceAutoExtract,
+ }); err != nil {
+ t.Fatalf("seed auto Add() error = %v", err)
+ }
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeUser,
+ Title: "语言偏好",
+ Content: "用户手动保存偏好中文回复。",
+ Source: SourceUserManual,
+ }); err != nil {
+ t.Fatalf("seed manual Add() error = %v", err)
+ }
+
+ candidates, err := svc.autoExtractionCandidates(context.Background())
+ if err != nil {
+ t.Fatalf("autoExtractionCandidates() error = %v", err)
+ }
+ var autoRef, manualRef string
+ for _, candidate := range candidates {
+ switch candidate.Source {
+ case SourceAutoExtract:
+ autoRef = candidate.Ref
+ case SourceUserManual:
+ manualRef = candidate.Ref
+ }
+ }
+ if autoRef == "" || manualRef == "" {
+ t.Fatalf("expected refs, auto=%q manual=%q candidates=%+v", autoRef, manualRef, candidates)
+ }
+
+ extractor := &stubDecisionMemoExtractor{
+ decisions: []ExtractionDecision{
+ {
+ Action: ExtractionActionUpdate,
+ Ref: autoRef,
+ Entry: Entry{
+ Title: "测试策略",
+ Content: "用户要求修改后先跑相关测试。",
+ Keywords: []string{"test"},
+ },
+ },
+ {
+ Action: ExtractionActionUpdate,
+ Ref: manualRef,
+ Entry: Entry{
+ Title: "语言偏好",
+ Content: "不应覆盖手动记忆。",
+ },
+ },
+ },
+ }
+ auto := NewAutoExtractor(extractor, svc, time.Second)
+ auto.debounce = 5 * time.Millisecond
+ auto.logf = func(string, ...any) {}
+ registerAutoExtractorCleanup(t, auto)
+
+ auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("更新测试策略")}}})
+ waitFor(t, time.Second, func() bool {
+ extractor.mu.Lock()
+ defer extractor.mu.Unlock()
+ return extractor.callCount == 1
+ })
+
+ recall, err := svc.Recall(context.Background(), "测试策略", ScopeProject)
+ if err != nil {
+ t.Fatalf("Recall(auto) error = %v", err)
+ }
+ if len(recall) != 1 || !strings.Contains(recall[0].Content, "相关测试") {
+ t.Fatalf("expected auto memory to be updated, got %+v", recall)
+ }
+ manualRecall, err := svc.Recall(context.Background(), "语言偏好", ScopeUser)
+ if err != nil {
+ t.Fatalf("Recall(manual) error = %v", err)
+ }
+ if len(manualRecall) != 1 || strings.Contains(manualRecall[0].Content, "不应覆盖") {
+ t.Fatalf("manual memory should not be overwritten, got %+v", manualRecall)
+ }
+ if len(extractor.candidates) != 2 {
+ t.Fatalf("expected existing candidates to be provided, got %+v", extractor.candidates)
+ }
+}
+
+func TestAutoExtractorSemanticCreateStillUsesExactDedup(t *testing.T) {
+ svc := newAutoExtractorTestService(t)
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeUser,
+ Title: "中文回复",
+ Content: "用户偏好中文回复。",
+ Source: SourceAutoExtract,
+ }); err != nil {
+ t.Fatalf("seed Add() error = %v", err)
+ }
+
+ extractor := &stubDecisionMemoExtractor{
+ decisions: []ExtractionDecision{
+ {Action: ExtractionActionCreate, Entry: Entry{Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"}},
+ {Action: ExtractionActionCreate, Entry: Entry{Type: TypeProject, Title: "新事实", Content: "项目需要语义去重。"}},
+ },
+ }
+ auto := NewAutoExtractor(extractor, svc, time.Second)
+ auto.debounce = 5 * time.Millisecond
+ auto.logf = func(string, ...any) {}
+ registerAutoExtractorCleanup(t, auto)
+
+ auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("dedupe")}}})
+ waitFor(t, time.Second, func() bool {
+ entries, err := svc.List(context.Background(), ScopeAll)
+ return err == nil && len(entries) == 2
+ })
+}
+
func TestAutoExtractorUsesTimeoutContext(t *testing.T) {
svc := newAutoExtractorTestService(t)
extractor := &stubMemoExtractor{
diff --git a/internal/memo/llm_extractor.go b/internal/memo/llm_extractor.go
index 9d039f44..621446de 100644
--- a/internal/memo/llm_extractor.go
+++ b/internal/memo/llm_extractor.go
@@ -27,6 +27,8 @@ type LLMExtractor struct {
}
type extractedEntry struct {
+ Action string `json:"action"`
+ Ref string `json:"ref"`
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
@@ -45,8 +47,27 @@ func NewLLMExtractor(generator TextGenerator, recentMessageLimit int) *LLMExtrac
}
}
-// Extract 从最近对话中提取可跨会话持久化的记忆条目。
+// Extract 从当前 run 对话中提取可跨会话持久化的新增记忆条目。
func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) {
+ decisions, err := e.ExtractDecisions(ctx, messages, nil)
+ if err != nil {
+ return nil, err
+ }
+ entries := make([]Entry, 0, len(decisions))
+ for _, decision := range decisions {
+ if decision.Action == ExtractionActionCreate {
+ entries = append(entries, decision.Entry)
+ }
+ }
+ return entries, nil
+}
+
+// ExtractDecisions 从当前 run 对话中提取记忆,并结合既有记忆输出新增、合并或跳过决策。
+func (e *LLMExtractor) ExtractDecisions(
+ ctx context.Context,
+ messages []providertypes.Message,
+ existing []ExtractionCandidate,
+) ([]ExtractionDecision, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
@@ -54,12 +75,12 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes
return nil, errors.New("memo: text generator is nil")
}
- recent := agentcontext.BuildRecentMessagesForModel(messages, e.recentMessageLimit)
- if len(recent) == 0 || !containsUserMessage(recent) {
+ runMessages := agentcontext.BuildMemoExtractionMessagesForModel(messages)
+ if len(runMessages) == 0 || !containsUserMessage(runMessages) {
return nil, nil
}
- response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now()), recent)
+ response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now(), existing), runMessages)
if err != nil {
return nil, err
}
@@ -77,23 +98,29 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes
return nil, fmt.Errorf("memo: parse extraction response: %w", err)
}
- entries := make([]Entry, 0, len(extracted))
+ decisions := make([]ExtractionDecision, 0, len(extracted))
for _, item := range extracted {
- entry, ok := toMemoEntry(item)
+ decision, ok := toExtractionDecision(item)
if !ok {
continue
}
- entries = append(entries, entry)
+ decisions = append(decisions, decision)
}
- return entries, nil
+ return decisions, nil
}
// buildExtractionPrompt 构造记忆提取专用的 system prompt。
-func buildExtractionPrompt(now time.Time) string {
+func buildExtractionPrompt(now time.Time, existing []ExtractionCandidate) string {
currentDate := now.In(time.Local).Format("2006-01-02")
+ existingJSON := "[]"
+ if len(existing) > 0 {
+ if data, err := json.Marshal(existing); err == nil {
+ existingJSON = string(data)
+ }
+ }
return strings.TrimSpace(fmt.Sprintf(`
你是一个记忆提取助手(memory extraction assistant)。
-分析最近对话中值得跨会话持久记住的信息,并返回严格 JSON 数组。
+分析当前 run 对话中值得跨会话持久记住的信息,并结合既有记忆完成语义去重,返回严格 JSON 数组。
当前本地日期:%s
如果对话中出现“今天、明天、下周二”等相对日期,必须先转换为绝对日期再写入 content。
@@ -109,11 +136,18 @@ func buildExtractionPrompt(now time.Time) string {
2. 不要提取通用编程知识、代码结构、文件路径、Git 历史。
3. 每条记忆必须具体、可操作。
4. 没有值得记住的信息时,返回 []。
-5. 输出必须是 JSON 数组,不要输出任何额外解释。
+5. 如果新信息与既有记忆语义相同或只是轻微改写,输出 action="skip"。
+6. 如果新信息能补充或修正既有 source="extractor_auto" 的记忆,输出 action="update" 并填写目标 ref。
+7. 不允许 update source 不是 "extractor_auto" 的既有记忆;这类相近内容只能 skip。
+8. 如果是全新的可持久化信息,输出 action="create"。
+9. 输出必须是 JSON 数组,不要输出任何额外解释。
+
+既有记忆候选(JSON):
+%s
输出格式:
-[{"type":"user","title":"...","content":"...","keywords":["..."]}]
-`, currentDate))
+[{"action":"create","type":"user","title":"...","content":"...","keywords":["..."]},{"action":"update","ref":"project:p.md","title":"...","content":"...","keywords":["..."]},{"action":"skip","ref":"user:u.md"}]
+`, currentDate, existingJSON))
}
// containsUserMessage 检查待提取消息中是否包含用户输入。
@@ -148,6 +182,69 @@ func toMemoEntry(item extractedEntry) (Entry, bool) {
}, true
}
+// toExtractionDecision 将 LLM 输出收敛为自动提取持久化决策。
+func toExtractionDecision(item extractedEntry) (ExtractionDecision, bool) {
+ action := parseExtractionAction(item.Action)
+ if action == "" {
+ return ExtractionDecision{}, false
+ }
+ if action == ExtractionActionSkip {
+ return ExtractionDecision{
+ Action: action,
+ Ref: strings.TrimSpace(item.Ref),
+ }, true
+ }
+
+ if action == ExtractionActionUpdate {
+ ref := strings.TrimSpace(item.Ref)
+ if ref == "" {
+ return ExtractionDecision{}, false
+ }
+ entry, ok := toMemoUpdateEntry(item)
+ if !ok {
+ return ExtractionDecision{}, false
+ }
+ return ExtractionDecision{Action: action, Ref: ref, Entry: entry}, true
+ }
+
+ entry, ok := toMemoEntry(item)
+ if !ok {
+ return ExtractionDecision{}, false
+ }
+ return ExtractionDecision{Action: action, Entry: entry}, true
+}
+
+// toMemoUpdateEntry 将 update 决策中的可变字段收敛为 Entry 片段。
+func toMemoUpdateEntry(item extractedEntry) (Entry, bool) {
+ title := NormalizeTitle(item.Title)
+ content := strings.TrimSpace(item.Content)
+ if title == "" || content == "" {
+ return Entry{}, false
+ }
+ return Entry{
+ Title: title,
+ Content: content,
+ Keywords: normalizeKeywords(item.Keywords),
+ Source: SourceAutoExtract,
+ }, true
+}
+
+// parseExtractionAction 解析模型决策动作,并兼容旧格式中缺省 action 的 create 输出。
+func parseExtractionAction(action string) ExtractionAction {
+ switch ExtractionAction(strings.ToLower(strings.TrimSpace(action))) {
+ case "":
+ return ExtractionActionCreate
+ case ExtractionActionCreate:
+ return ExtractionActionCreate
+ case ExtractionActionUpdate:
+ return ExtractionActionUpdate
+ case ExtractionActionSkip:
+ return ExtractionActionSkip
+ default:
+ return ""
+ }
+}
+
// normalizeKeywords 规范化关键词列表,移除空值和重复值。
func normalizeKeywords(keywords []string) []string {
if len(keywords) == 0 {
diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go
index 3ccf9ee9..06813229 100644
--- a/internal/memo/llm_extractor_test.go
+++ b/internal/memo/llm_extractor_test.go
@@ -196,8 +196,8 @@ func TestLLMExtractorExtractCancelledContext(t *testing.T) {
}
}
-// TestLLMExtractorExtractUsesRecentNonToolMessages 验证只取最近 10 条非 tool 消息。
-func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) {
+// TestLLMExtractorExtractUsesFullRunMessages 验证提取器使用完整 run 消息而非固定 recent window。
+func TestLLMExtractorExtractUsesFullRunMessages(t *testing.T) {
generator := &stubTextGenerator{response: `[]`}
extractor := NewLLMExtractor(generator, 10)
@@ -219,19 +219,64 @@ func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) {
if err != nil {
t.Fatalf("Extract() error = %v", err)
}
- if len(generator.messages) != 10 {
- t.Fatalf("len(generator.messages) = %d, want 10", len(generator.messages))
+ if len(generator.messages) != 12 {
+ t.Fatalf("len(generator.messages) = %d, want 12", len(generator.messages))
}
for _, message := range generator.messages {
if message.Role == providertypes.RoleTool {
t.Fatalf("unexpected tool message in extraction context: %#v", message)
}
}
- if renderMemoParts(generator.messages[0].Parts) != "user-c" ||
- renderMemoParts(generator.messages[9].Parts) != "user-l" {
- t.Fatalf("unexpected recent window: first=%q last=%q",
+ if renderMemoParts(generator.messages[0].Parts) != "user-a" ||
+ renderMemoParts(generator.messages[11].Parts) != "user-l" {
+ t.Fatalf("unexpected run window: first=%q last=%q",
renderMemoParts(generator.messages[0].Parts),
- renderMemoParts(generator.messages[9].Parts))
+ renderMemoParts(generator.messages[11].Parts))
+ }
+}
+
+func TestLLMExtractorExtractDecisionsIncludesExistingCandidates(t *testing.T) {
+ generator := &stubTextGenerator{
+ response: `[{"action":"skip","ref":"user:u.md"},{"action":"update","ref":"project:p.md","title":"测试策略","content":"用户要求修改后先跑相关测试。","keywords":["test"]}]`,
+ }
+ extractor := NewLLMExtractor(generator, 10)
+
+ decisions, err := extractor.ExtractDecisions(
+ context.Background(),
+ []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑相关测试。")}},
+ },
+ []ExtractionCandidate{
+ {
+ Ref: "project:p.md",
+ Scope: ScopeProject,
+ Type: TypeFeedback,
+ Source: SourceAutoExtract,
+ Title: "测试策略",
+ Content: "用户要求修改后先跑测试。",
+ },
+ },
+ )
+ if err != nil {
+ t.Fatalf("ExtractDecisions() error = %v", err)
+ }
+ if len(decisions) != 2 {
+ t.Fatalf("len(decisions) = %d, want 2", len(decisions))
+ }
+ if decisions[0].Action != ExtractionActionSkip || decisions[0].Ref != "user:u.md" {
+ t.Fatalf("unexpected skip decision: %+v", decisions[0])
+ }
+ if decisions[1].Action != ExtractionActionUpdate || decisions[1].Ref != "project:p.md" {
+ t.Fatalf("unexpected update decision: %+v", decisions[1])
+ }
+ if decisions[1].Entry.Type != "" {
+ t.Fatalf("update decision should not require type, got %+v", decisions[1].Entry)
+ }
+ if !strings.Contains(generator.prompt, `"ref":"project:p.md"`) {
+ t.Fatalf("prompt should include existing memory candidates, got %q", generator.prompt)
+ }
+ if !strings.Contains(generator.prompt, `action="update"`) || !strings.Contains(generator.prompt, `source="extractor_auto"`) {
+ t.Fatalf("prompt should describe semantic dedupe protocol, got %q", generator.prompt)
}
}
diff --git a/internal/memo/service.go b/internal/memo/service.go
index c7415c50..42551f26 100644
--- a/internal/memo/service.go
+++ b/internal/memo/service.go
@@ -68,6 +68,130 @@ func (s *Service) addAutoExtractIfAbsent(ctx context.Context, entry Entry) (bool
return true, nil
}
+// autoExtractionCandidates 加载既有记忆快照,供模型在提取时做语义去重判断。
+func (s *Service) autoExtractionCandidates(ctx context.Context) ([]ExtractionCandidate, error) {
+ if s == nil {
+ return nil, nil
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ candidates := make([]ExtractionCandidate, 0)
+ for _, scope := range supportedStorageScopes() {
+ index, err := s.loadIndexLocked(ctx, scope)
+ if err != nil {
+ return nil, err
+ }
+ for _, entry := range index.Entries {
+ topicFile := strings.TrimSpace(entry.TopicFile)
+ if topicFile == "" {
+ continue
+ }
+ topicContent, err := s.store.LoadTopic(ctx, scope, topicFile)
+ if err != nil {
+ continue
+ }
+ source, content := parseTopicSourceAndContent(topicContent)
+ candidates = append(candidates, ExtractionCandidate{
+ Ref: scopedTopicKey(scope, topicFile),
+ Scope: scope,
+ Type: entry.Type,
+ Source: source,
+ Title: entry.Title,
+ Content: content,
+ })
+ }
+ }
+ return candidates, nil
+}
+
+// updateAutoExtractIfAllowed 按 ref 更新既有自动提取记忆,显式记忆不会被后台流程覆盖。
+func (s *Service) updateAutoExtractIfAllowed(ctx context.Context, ref string, next Entry) (bool, error) {
+ next.Title = NormalizeTitle(next.Title)
+ next.Content = strings.TrimSpace(next.Content)
+ next.Keywords = normalizeKeywords(next.Keywords)
+ if next.Title == "" {
+ return false, fmt.Errorf("memo: title is empty")
+ }
+ if next.Content == "" {
+ return false, fmt.Errorf("memo: content is empty")
+ }
+ if err := s.ensureAutoExtractIndex(ctx); err != nil {
+ return false, err
+ }
+
+ scope, topicFile, ok := parseScopedTopicKey(ref)
+ if !ok {
+ return false, nil
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ index, err := s.loadIndexLocked(ctx, scope)
+ if err != nil {
+ return false, err
+ }
+ working := cloneIndex(index)
+
+ targetIndex := -1
+ var current Entry
+ for idx, existing := range working.Entries {
+ if strings.TrimSpace(existing.TopicFile) == topicFile {
+ targetIndex = idx
+ current = existing
+ break
+ }
+ }
+ if targetIndex < 0 {
+ return false, nil
+ }
+
+ topicContent, err := s.store.LoadTopic(ctx, scope, topicFile)
+ if err != nil {
+ return false, nil
+ }
+ source, _ := parseTopicSourceAndContent(topicContent)
+ if source != SourceAutoExtract {
+ return false, nil
+ }
+
+ current.Title = next.Title
+ current.Content = next.Content
+ current.Keywords = next.Keywords
+ current.Source = SourceAutoExtract
+ current.TopicFile = topicFile
+ working.Entries[targetIndex] = current
+ working.UpdatedAt = time.Now()
+
+ removedEntries := trimIndexEntries(working, s.config.MaxEntries, s.config.MaxIndexBytes)
+ if err := s.store.SaveTopic(ctx, scope, topicFile, RenderTopic(¤t)); err != nil {
+ return false, fmt.Errorf("memo: save topic: %w", err)
+ }
+ if err := s.store.SaveIndex(ctx, scope, working); err != nil {
+ return false, fmt.Errorf("memo: save index: %w", err)
+ }
+ for _, removed := range removedEntries {
+ if removedTopic := strings.TrimSpace(removed.TopicFile); removedTopic != "" {
+ _ = s.store.DeleteTopic(ctx, scope, removedTopic)
+ }
+ }
+
+ if s.autoExtractIndexReady {
+ s.removeAutoExtractTopicLocked(scope, topicFile)
+ for _, removed := range removedEntries {
+ s.removeAutoExtractTopicLocked(scope, removed.TopicFile)
+ }
+ if indexContainsTopicFile(working, topicFile) {
+ s.trackAutoExtractEntryLocked(scope, current)
+ }
+ }
+
+ s.invalidateCache()
+ return true, nil
+}
+
// normalizeKeyword 统一关键词的空格与大小写处理。
func normalizeKeyword(keyword string) string {
return strings.ToLower(strings.TrimSpace(keyword))
@@ -557,6 +681,23 @@ func indexContainsEntryID(index *Index, entryID string) bool {
return false
}
+// indexContainsTopicFile 判断索引中是否仍保留指定 topic 文件。
+func indexContainsTopicFile(index *Index, topicFile string) bool {
+ if index == nil {
+ return false
+ }
+ topicFile = strings.TrimSpace(topicFile)
+ if topicFile == "" {
+ return false
+ }
+ for _, item := range index.Entries {
+ if strings.TrimSpace(item.TopicFile) == topicFile {
+ return true
+ }
+ }
+ return false
+}
+
// scopesForQuery 将查询范围展开为实际存储分层列表。
func scopesForQuery(scope Scope) []Scope {
switch NormalizeScope(scope) {
@@ -584,3 +725,20 @@ func validateQueryScope(scope Scope) error {
func scopedTopicKey(scope Scope, topicFile string) string {
return string(scope) + ":" + strings.TrimSpace(topicFile)
}
+
+// parseScopedTopicKey 解析自动提取语义去重协议中的 ref 字段。
+func parseScopedTopicKey(ref string) (Scope, string, bool) {
+ parts := strings.SplitN(strings.TrimSpace(ref), ":", 2)
+ if len(parts) != 2 {
+ return "", "", false
+ }
+ scope := Scope(strings.TrimSpace(parts[0]))
+ if err := validateStorageScope(scope); err != nil {
+ return "", "", false
+ }
+ topicFile := strings.TrimSpace(parts[1])
+ if topicFile == "" {
+ return "", "", false
+ }
+ return scope, topicFile, true
+}
diff --git a/internal/memo/types.go b/internal/memo/types.go
index 8572e0e4..0d1a62a0 100644
--- a/internal/memo/types.go
+++ b/internal/memo/types.go
@@ -89,6 +89,35 @@ type RecalledEntry struct {
Content string
}
+// ExtractionAction 表示自动提取器对单条候选记忆的持久化决策。
+type ExtractionAction string
+
+const (
+ // ExtractionActionCreate 表示新增一条记忆。
+ ExtractionActionCreate ExtractionAction = "create"
+ // ExtractionActionUpdate 表示合并更新一条既有自动提取记忆。
+ ExtractionActionUpdate ExtractionAction = "update"
+ // ExtractionActionSkip 表示跳过重复或不值得沉淀的内容。
+ ExtractionActionSkip ExtractionAction = "skip"
+)
+
+// ExtractionCandidate 表示提供给模型做语义去重的既有记忆快照。
+type ExtractionCandidate struct {
+ Ref string `json:"ref"`
+ Scope Scope `json:"scope"`
+ Type Type `json:"type"`
+ Source string `json:"source"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+}
+
+// ExtractionDecision 表示模型针对新旧记忆关系返回的结构化决策。
+type ExtractionDecision struct {
+ Action ExtractionAction
+ Ref string
+ Entry Entry
+}
+
// Store 定义记忆持久化的最小抽象。
type Store interface {
LoadIndex(ctx context.Context, scope Scope) (*Index, error)
@@ -104,6 +133,15 @@ type Extractor interface {
Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error)
}
+// DecisionExtractor 定义带既有记忆快照的语义提取能力。
+type DecisionExtractor interface {
+ ExtractDecisions(
+ ctx context.Context,
+ messages []providertypes.Message,
+ existing []ExtractionCandidate,
+ ) ([]ExtractionDecision, error)
+}
+
// TextGenerator 定义调用 LLM 生成文本的最小能力,用于记忆提取。
// 该接口隔离 provider 细节,避免 memo 包直接依赖 provider 基础设施。
type TextGenerator interface {
diff --git a/internal/runtime/memo.go b/internal/runtime/memo.go
index 477933d7..d2ce6202 100644
--- a/internal/runtime/memo.go
+++ b/internal/runtime/memo.go
@@ -19,6 +19,30 @@ func (s *Service) triggerMemoExtraction(sessionID string, messages []providertyp
s.memoExtractor.Schedule(sessionID, cloneMessages(messages))
}
+// runBoundaryMessagesForMemo 返回当前 run 边界内的消息切片,供自动记忆提取使用。
+func runBoundaryMessagesForMemo(state *runState) []providertypes.Message {
+ if state == nil {
+ return nil
+ }
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ return cloneMessages(state.memoRunMessages)
+}
+
+// appendMemoRunMessage 记录当前 run 内已成功写入 transcript 的消息,作为自动记忆提取边界。
+func appendMemoRunMessage(state *runState, message providertypes.Message) {
+ if state == nil || message.IsEmpty() {
+ return
+ }
+ cloned := cloneMessages([]providertypes.Message{message})
+ if len(cloned) == 0 {
+ return
+ }
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ state.memoRunMessages = append(state.memoRunMessages, cloned[0])
+}
+
// isSuccessfulRememberToolCall 判断工具调用是否成功完成显式记忆写入。
func isSuccessfulRememberToolCall(callName string, result tools.ToolResult, execErr error) bool {
if execErr != nil || result.IsError {
diff --git a/internal/runtime/run.go b/internal/runtime/run.go
index 6887ee55..4fb3d22d 100644
--- a/internal/runtime/run.go
+++ b/internal/runtime/run.go
@@ -472,7 +472,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) {
})
recordAcceptanceTerminal(&state, acceptanceDecision)
s.emitRunScoped(ctx, EventAgentDone, &state, turnOutput.assistant)
- s.triggerMemoExtraction(state.session.ID, state.session.Messages, state.rememberedThisRun)
+ s.triggerMemoExtraction(state.session.ID, runBoundaryMessagesForMemo(&state), state.rememberedThisRun)
return nil
case acceptance.AcceptanceContinue:
state.lastAcceptanceBlockedReason = strings.TrimSpace(acceptanceDecision.CompletionBlockedReason)
diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go
index f1db6601..a8f72459 100644
--- a/internal/runtime/runtime.go
+++ b/internal/runtime/runtime.go
@@ -145,7 +145,7 @@ type ProviderFactory interface {
// MemoExtractor 定义 runtime 层调用记忆提取的最小能力。
// 通过接口注入避免 runtime 直接依赖 memo 子系统实现细节。
type MemoExtractor interface {
- // Schedule 从消息中安排一次后台记忆提取,失败由实现方自行处理。
+ // Schedule 从当前 run 边界内的消息安排一次后台记忆提取,失败由实现方自行处理。
Schedule(sessionID string, messages []providertypes.Message)
}
diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go
index 9a38cc91..5e1ca56d 100644
--- a/internal/runtime/runtime_test.go
+++ b/internal/runtime/runtime_test.go
@@ -1180,6 +1180,52 @@ func TestServiceRunSchedulesMemoExtractionAfterFinalReply(t *testing.T) {
}
}
+func TestServiceRunSchedulesMemoExtractionFromCurrentRunBoundary(t *testing.T) {
+ t.Parallel()
+
+ manager := newRuntimeConfigManager(t)
+ store := newMemoryStore()
+ session := agentsession.New("existing")
+ session.ID = "session-existing-memo"
+ session.Messages = []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("old user")}},
+ {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("old assistant")}},
+ }
+ store.sessions[session.ID] = cloneSession(session)
+
+ scripted := &scriptedProvider{
+ streams: [][]providertypes.StreamEvent{
+ {
+ providertypes.NewTextDeltaStreamEvent("new final"),
+ providertypes.NewMessageDoneStreamEvent("stop", nil),
+ },
+ },
+ }
+ service := NewWithFactory(manager, tools.NewRegistry(), store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{})
+ memoExtractor := &stubScheduledMemoExtractor{}
+ service.SetMemoExtractor(memoExtractor)
+
+ err := service.Run(context.Background(), UserInput{
+ SessionID: session.ID,
+ RunID: "run-memo-boundary",
+ Parts: []providertypes.ContentPart{providertypes.NewTextPart("new user")},
+ })
+ if err != nil {
+ t.Fatalf("Run() error = %v", err)
+ }
+ if len(memoExtractor.calls) != 1 {
+ t.Fatalf("memo schedule calls = %d, want 1", len(memoExtractor.calls))
+ }
+ messages := memoExtractor.calls[0].messages
+ if len(messages) != 2 {
+ t.Fatalf("scheduled messages = %#v, want current run user+assistant only", messages)
+ }
+ if renderPartsForVerification(messages[0].Parts) != "new user" ||
+ renderPartsForVerification(messages[1].Parts) != "new final" {
+ t.Fatalf("scheduled messages crossed run boundary: %#v", messages)
+ }
+}
+
func TestServiceRunSkipsAutoMemoExtractionAfterRememberTool(t *testing.T) {
t.Parallel()
@@ -5555,6 +5601,8 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) {
}
service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, builder)
+ memoExtractor := &stubScheduledMemoExtractor{}
+ service.SetMemoExtractor(memoExtractor)
service.compactRunner = &stubCompactRunner{
result: contextcompact.Result{
Messages: []providertypes.Message{
@@ -5617,6 +5665,16 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) {
if renderPartsForTest(saved.Messages[2].Parts) != "recovered" {
t.Fatalf("expected final assistant reply %q, got %q", "recovered", renderPartsForTest(saved.Messages[2].Parts))
}
+ if len(memoExtractor.calls) != 1 {
+ t.Fatalf("memo schedule calls = %d, want 1", len(memoExtractor.calls))
+ }
+ memoMessages := memoExtractor.calls[0].messages
+ if len(memoMessages) != 2 {
+ t.Fatalf("memo messages = %+v, want current run user+assistant only after compact", memoMessages)
+ }
+ if renderPartsForTest(memoMessages[0].Parts) != "continue" || renderPartsForTest(memoMessages[1].Parts) != "recovered" {
+ t.Fatalf("memo messages crossed compact boundary: %+v", memoMessages)
+ }
events := collectRuntimeEvents(service.Events())
assertEventSequence(t, events, []EventType{
diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go
index 26d3d9bb..9f00db6f 100644
--- a/internal/runtime/session_mutation.go
+++ b/internal/runtime/session_mutation.go
@@ -29,6 +29,7 @@ func (s *Service) appendUserMessageAndSave(ctx context.Context, state *runState,
}); err != nil {
return err
}
+ appendMemoRunMessage(state, message)
s.emitRunScoped(ctx, EventUserMessage, state, message)
return nil
}
@@ -87,7 +88,7 @@ func (s *Service) appendAssistantMessageOnlyAndSave(
}
state.session.Messages = append(state.session.Messages, assistant)
state.touchSession()
- return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{
+ if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{
SessionID: state.session.ID,
Messages: []providertypes.Message{assistant},
UpdatedAt: state.session.UpdatedAt,
@@ -95,7 +96,11 @@ func (s *Service) appendAssistantMessageOnlyAndSave(
Model: state.session.Model,
Workdir: state.session.Workdir,
HasUnknownUsage: state.session.HasUnknownUsage,
- })
+ }); err != nil {
+ return err
+ }
+ appendMemoRunMessage(state, assistant)
+ return nil
}
// appendSystemMessageAndSave 将系统提醒消息追加到会话并持久化。
@@ -111,7 +116,7 @@ func (s *Service) appendSystemMessageAndSave(ctx context.Context, state *runStat
}
state.session.Messages = append(state.session.Messages, message)
state.touchSession()
- return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{
+ if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{
SessionID: state.session.ID,
Messages: []providertypes.Message{message},
UpdatedAt: state.session.UpdatedAt,
@@ -119,7 +124,11 @@ func (s *Service) appendSystemMessageAndSave(ctx context.Context, state *runStat
Model: state.session.Model,
Workdir: state.session.Workdir,
HasUnknownUsage: state.session.HasUnknownUsage,
- })
+ }); err != nil {
+ return err
+ }
+ appendMemoRunMessage(state, message)
+ return nil
}
// appendToolMessageAndSave 将工具原始结果写回会话,持久化时仅追加一条 tool message。
@@ -143,7 +152,11 @@ func (s *Service) appendToolMessageAndSave(
HasUnknownUsage: state.session.HasUnknownUsage,
}
state.mu.Unlock()
- return s.sessionStore.AppendMessages(ctx, input)
+ if err := s.sessionStore.AppendMessages(ctx, input); err != nil {
+ return err
+ }
+ appendMemoRunMessage(state, toolMessage)
+ return nil
}
// normalizeToolMessageForPersistence 负责在写入会话前收敛工具结果,避免成功结果落成完全空语义消息。
diff --git a/internal/runtime/state.go b/internal/runtime/state.go
index 2440bddf..277d9c6e 100644
--- a/internal/runtime/state.go
+++ b/internal/runtime/state.go
@@ -4,6 +4,7 @@ import (
"sync"
"time"
+ providertypes "neo-code/internal/provider/types"
"neo-code/internal/runtime/controlplane"
"neo-code/internal/runtime/decider"
runtimefacts "neo-code/internal/runtime/facts"
@@ -20,6 +21,7 @@ type runState struct {
effectiveWorkdir string
compactCount int
reactiveCompactAttempts int
+ memoRunMessages []providertypes.Message
rememberedThisRun bool
planningEnabled bool
taskID string
From ec23d015dc70cc85d92265d66e67bf17aae34cd0 Mon Sep 17 00:00:00 2001
From: Yumiue <229866007@qq.com>
Date: Sun, 10 May 2026 10:47:56 +0800
Subject: [PATCH 4/5] =?UTF-8?q?fix:web=E7=AB=AFslash=E6=8C=87=E4=BB=A4?=
=?UTF-8?q?=E4=BF=AE=E5=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
web/src/components/chat/ChatInput.test.tsx | 69 ++++++++++++++++++++++
web/src/components/chat/ChatInput.tsx | 25 ++++----
2 files changed, 79 insertions(+), 15 deletions(-)
diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx
index 8c610e7b..bf39b1b2 100644
--- a/web/src/components/chat/ChatInput.test.tsx
+++ b/web/src/components/chat/ChatInput.test.tsx
@@ -25,6 +25,12 @@ vi.mock('./ModelSelector', () => ({
default: () => ,
}))
+async function submitSlashCommand(command: string) {
+ const textarea = screen.getByRole('textbox') as HTMLTextAreaElement
+ fireEvent.change(textarea, { target: { value: command } })
+ fireEvent.keyDown(textarea, { key: 'Enter' })
+}
+
describe('ChatInput', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -134,4 +140,67 @@ describe('ChatInput', () => {
expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument()
expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument()
})
+ it('executes /memo without session id and shows payload.Content', async () => {
+ mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({
+ payload: {
+ Content: 'User Memo:\n- [user] coding preference',
+ },
+ })
+ render()
+
+ await submitSlashCommand('/memo')
+
+ await waitFor(() => {
+ expect(mockGatewayAPI.executeSystemTool).toHaveBeenCalledWith('', '', 'memo_list', {})
+ })
+ await waitFor(() => {
+ expect(useChatStore.getState().messages.some((msg) => msg.type === 'system' && msg.content.includes('coding preference'))).toBe(true)
+ })
+ })
+
+ it('uses fallback text when memo payload has no content field', async () => {
+ mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({ payload: {} })
+ render()
+
+ await submitSlashCommand('/memo')
+
+ await waitFor(() => {
+ expect(useChatStore.getState().messages.some((msg) => msg.type === 'system' && msg.content === 'Memo query complete')).toBe(true)
+ })
+ })
+
+ it('executes /remember and /forget without session id', async () => {
+ mockGatewayAPI.executeSystemTool
+ .mockResolvedValueOnce({ payload: { Content: 'Memory saved: [user] keep tests strict' } })
+ .mockResolvedValueOnce({ payload: { Content: 'Removed 1 memo(s) matching \"strict\".' } })
+ render()
+
+ await submitSlashCommand('/remember keep tests strict')
+ await waitFor(() => {
+ expect(mockGatewayAPI.executeSystemTool).toHaveBeenNthCalledWith(1, '', '', 'memo_remember', {
+ type: 'user',
+ title: 'keep tests strict',
+ content: 'keep tests strict',
+ })
+ })
+
+ await submitSlashCommand('/forget strict')
+ await waitFor(() => {
+ expect(mockGatewayAPI.executeSystemTool).toHaveBeenNthCalledWith(2, '', '', 'memo_remove', {
+ keyword: 'strict',
+ scope: 'all',
+ })
+ })
+ })
+
+ it('keeps argument validation for /remember and /forget', async () => {
+ render()
+
+ await submitSlashCommand('/remember')
+ await submitSlashCommand('/forget')
+
+ await waitFor(() => {
+ expect(mockGatewayAPI.executeSystemTool).not.toHaveBeenCalled()
+ })
+ })
})
diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx
index ff4258e4..d5a4bb2e 100644
--- a/web/src/components/chat/ChatInput.tsx
+++ b/web/src/components/chat/ChatInput.tsx
@@ -57,6 +57,13 @@ function buildSlashHelpText(commands: AnySlashCommand[]): string {
return ['可用命令:', ...lines].join('\n')
}
+/** 统一提取系统工具返回文本,兼容 payload.content 与 payload.Content。 */
+function extractSystemToolContent(result: unknown, fallback: string): string {
+ const payload = (result as { payload?: { content?: string; Content?: string } } | null)?.payload
+ const content = payload?.content ?? payload?.Content
+ return content || fallback
+}
+
export default function ChatInput() {
const gatewayAPI = useGatewayAPI()
const text = useComposerStore((state) => state.composerText)
@@ -149,13 +156,9 @@ export default function ChatInput() {
return true
}
case '/memo': {
- if (!isValidSessionId(currentSessionId)) {
- useUIStore.getState().showToast('Send a message first to start a session', 'error')
- return true
- }
try {
const result = await api.executeSystemTool(currentSessionId, '', 'memo_list', {})
- addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo query complete')
+ addSystemMessage(extractSystemToolContent(result, 'Memo query complete'))
} catch (err) {
console.error('Memo list failed:', err)
useUIStore.getState().showToast('Failed to query memo', 'error')
@@ -167,17 +170,13 @@ export default function ChatInput() {
useUIStore.getState().showToast('Usage: /remember ', 'error')
return true
}
- if (!isValidSessionId(currentSessionId)) {
- useUIStore.getState().showToast('Send a message first to start a session', 'error')
- return true
- }
try {
const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', {
type: 'user',
title: argument,
content: argument,
})
- addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo saved')
+ addSystemMessage(extractSystemToolContent(result, 'Memo saved'))
} catch (err) {
console.error('Remember failed:', err)
useUIStore.getState().showToast('Failed to save memo', 'error')
@@ -189,16 +188,12 @@ export default function ChatInput() {
useUIStore.getState().showToast('Usage: /forget ', 'error')
return true
}
- if (!isValidSessionId(currentSessionId)) {
- useUIStore.getState().showToast('Send a message first to start a session', 'error')
- return true
- }
try {
const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', {
keyword: argument,
scope: 'all',
})
- addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo deleted')
+ addSystemMessage(extractSystemToolContent(result, 'Memo deleted'))
} catch (err) {
console.error('Forget failed:', err)
useUIStore.getState().showToast('Failed to delete memo', 'error')
From 0a4e22765b138878e1d2c146e6cc948db5c64fed Mon Sep 17 00:00:00 2001
From: Yumiue <229866007@qq.com>
Date: Sun, 10 May 2026 16:16:15 +0800
Subject: [PATCH 5/5] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=8F=90=E5=8F=96?=
=?UTF-8?q?=E7=9A=84=E6=96=87=E6=9C=AC=E5=86=85=E5=AE=B9=EF=BC=8CAutoExtra?=
=?UTF-8?q?ctor=20=E5=AF=B9=E6=AF=8F=E6=9D=A1=E6=96=B0=E5=80=99=E9=80=89?=
=?UTF-8?q?=E5=8D=95=E7=8B=AC=E5=81=9A=E5=8E=BB=E9=87=8D=E5=86=B3=E7=AD=96?=
=?UTF-8?q?=EF=BC=8C=E7=A7=BB=E9=99=A4extract=5Frecent=5Fmessages?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
internal/app/bootstrap.go | 2 +-
internal/app/bootstrap_test.go | 1 -
internal/config/config_test.go | 46 +--
internal/config/loader.go | 63 +++-
internal/config/loader_test.go | 34 +-
internal/config/memo.go | 35 +-
internal/context/projection.go | 2 +-
internal/context/projection_test.go | 6 +
internal/memo/auto_extractor.go | 113 ++-----
internal/memo/auto_extractor_test.go | 41 ++-
internal/memo/llm_extractor.go | 220 ++++++++----
internal/memo/llm_extractor_test.go | 484 +++++++++------------------
internal/memo/semantic_candidates.go | 321 ++++++++++++++++++
internal/memo/service.go | 40 ++-
internal/memo/service_test.go | 91 ++++-
internal/memo/types.go | 23 +-
internal/tools/memo/remember_test.go | 7 +-
17 files changed, 948 insertions(+), 581 deletions(-)
create mode 100644 internal/memo/semantic_candidates.go
diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go
index f7c6ce5f..41bea565 100644
--- a/internal/app/bootstrap.go
+++ b/internal/app/bootstrap.go
@@ -115,7 +115,7 @@ func newMemoExtractorAdapter(
})
})
- scheduler.ScheduleWithExtractor(sessionID, messages, memo.NewLLMExtractor(generator, cfg.Memo.ExtractRecentMessages))
+ scheduler.ScheduleWithExtractor(sessionID, messages, memo.NewLLMExtractor(generator))
})
}
diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go
index 917823cb..a1af4c49 100644
--- a/internal/app/bootstrap_test.go
+++ b/internal/app/bootstrap_test.go
@@ -1679,7 +1679,6 @@ func TestNewMemoExtractorAdapterUsesFullRunMemoWindow(t *testing.T) {
t.Setenv(config.OpenAIDefaultAPIKeyEnv, "token")
cfg := config.StaticDefaults().Clone()
cfg.SelectedProvider = config.OpenAIName
- cfg.Memo.ExtractRecentMessages = 3
manager := config.NewManager(config.NewLoader("", &cfg))
providerStub := &stubMemoProvider{
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index 4e5b8c39..149307f9 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -1495,12 +1495,11 @@ func TestMemoConfigClone(t *testing.T) {
t.Parallel()
original := MemoConfig{
- Enabled: true,
- AutoExtract: false,
- MaxEntries: 100,
- MaxIndexBytes: 2048,
- ExtractTimeoutSec: 9,
- ExtractRecentMessages: 3,
+ Enabled: true,
+ AutoExtract: false,
+ MaxEntries: 100,
+ MaxIndexBytes: 2048,
+ ExtractTimeoutSec: 9,
}
cloned := original.Clone()
if cloned != original {
@@ -1518,10 +1517,9 @@ func TestMemoConfigApplyDefaults(t *testing.T) {
t.Run("fills zero fields", func(t *testing.T) {
cfg := MemoConfig{}
cfg.ApplyDefaults(MemoConfig{
- MaxEntries: DefaultMemoMaxEntries,
- MaxIndexBytes: DefaultMemoMaxIndexBytes,
- ExtractTimeoutSec: DefaultMemoExtractTimeoutSec,
- ExtractRecentMessages: DefaultMemoExtractRecentMessage,
+ MaxEntries: DefaultMemoMaxEntries,
+ MaxIndexBytes: DefaultMemoMaxIndexBytes,
+ ExtractTimeoutSec: DefaultMemoExtractTimeoutSec,
})
if cfg.MaxEntries != DefaultMemoMaxEntries {
t.Errorf("MaxEntries = %d, want %d", cfg.MaxEntries, DefaultMemoMaxEntries)
@@ -1532,33 +1530,28 @@ func TestMemoConfigApplyDefaults(t *testing.T) {
if cfg.ExtractTimeoutSec != DefaultMemoExtractTimeoutSec {
t.Errorf("ExtractTimeoutSec = %d, want %d", cfg.ExtractTimeoutSec, DefaultMemoExtractTimeoutSec)
}
- if cfg.ExtractRecentMessages != DefaultMemoExtractRecentMessage {
- t.Errorf("ExtractRecentMessages = %d, want %d", cfg.ExtractRecentMessages, DefaultMemoExtractRecentMessage)
- }
})
t.Run("preserves explicit fields", func(t *testing.T) {
cfg := MemoConfig{
- MaxEntries: 50,
- MaxIndexBytes: 1024,
- ExtractTimeoutSec: 30,
- ExtractRecentMessages: 5,
+ MaxEntries: 50,
+ MaxIndexBytes: 1024,
+ ExtractTimeoutSec: 30,
}
cfg.ApplyDefaults(defaultMemoConfig())
- if cfg.MaxEntries != 50 || cfg.MaxIndexBytes != 1024 || cfg.ExtractTimeoutSec != 30 || cfg.ExtractRecentMessages != 5 {
+ if cfg.MaxEntries != 50 || cfg.MaxIndexBytes != 1024 || cfg.ExtractTimeoutSec != 30 {
t.Fatalf("ApplyDefaults() unexpectedly overwrote explicit values: %+v", cfg)
}
})
t.Run("preserves negative fields for validation", func(t *testing.T) {
cfg := MemoConfig{
- MaxEntries: -1,
- MaxIndexBytes: -2,
- ExtractTimeoutSec: -3,
- ExtractRecentMessages: -4,
+ MaxEntries: -1,
+ MaxIndexBytes: -2,
+ ExtractTimeoutSec: -3,
}
cfg.ApplyDefaults(defaultMemoConfig())
- if cfg.MaxEntries != -1 || cfg.MaxIndexBytes != -2 || cfg.ExtractTimeoutSec != -3 || cfg.ExtractRecentMessages != -4 {
+ if cfg.MaxEntries != -1 || cfg.MaxIndexBytes != -2 || cfg.ExtractTimeoutSec != -3 {
t.Fatalf("ApplyDefaults() unexpectedly rewrote invalid values: %+v", cfg)
}
})
@@ -1603,13 +1596,6 @@ func TestMemoConfigValidate(t *testing.T) {
}
})
- t.Run("non-positive ExtractRecentMessages", func(t *testing.T) {
- cfg := defaultMemoConfig()
- cfg.ExtractRecentMessages = 0
- if err := cfg.Validate(); err == nil {
- t.Fatal("non-positive ExtractRecentMessages should fail validation")
- }
- })
}
func TestNormalizeWorkdirEdgeCases(t *testing.T) {
diff --git a/internal/config/loader.go b/internal/config/loader.go
index 601d214c..a07a689c 100644
--- a/internal/config/loader.go
+++ b/internal/config/loader.go
@@ -66,12 +66,11 @@ type persistedAskConfig struct {
}
type persistedMemoConfig struct {
- Enabled *bool `yaml:"enabled,omitempty"`
- AutoExtract *bool `yaml:"auto_extract,omitempty"`
- MaxEntries *int `yaml:"max_entries,omitempty"`
- MaxIndexBytes *int `yaml:"max_index_bytes,omitempty"`
- ExtractTimeoutSec *int `yaml:"extract_timeout_sec,omitempty"`
- ExtractRecentMessages *int `yaml:"extract_recent_messages,omitempty"`
+ Enabled *bool `yaml:"enabled,omitempty"`
+ AutoExtract *bool `yaml:"auto_extract,omitempty"`
+ MaxEntries *int `yaml:"max_entries,omitempty"`
+ MaxIndexBytes *int `yaml:"max_index_bytes,omitempty"`
+ ExtractTimeoutSec *int `yaml:"extract_timeout_sec,omitempty"`
}
func NewLoader(baseDir string, defaults *Config) *Loader {
@@ -225,6 +224,9 @@ func parseConfigWithContextDefaults(
}
func parseCurrentConfig(data []byte, contextDefaults ContextConfig, memoDefaults MemoConfig) (*Config, error) {
+ if err := rejectRemovedMemoFields(data); err != nil {
+ return nil, err
+ }
var file persistedConfig
decoder := yaml.NewDecoder(bytes.NewReader(data))
decoder.KnownFields(true)
@@ -384,14 +386,12 @@ func newPersistedMemoConfig(cfg MemoConfig) persistedMemoConfig {
maxEntries := cfg.MaxEntries
maxIndexBytes := cfg.MaxIndexBytes
extractTimeoutSec := cfg.ExtractTimeoutSec
- extractRecentMessages := cfg.ExtractRecentMessages
return persistedMemoConfig{
- Enabled: &enabled,
- AutoExtract: &autoExtract,
- MaxEntries: &maxEntries,
- MaxIndexBytes: &maxIndexBytes,
- ExtractTimeoutSec: &extractTimeoutSec,
- ExtractRecentMessages: &extractRecentMessages,
+ Enabled: &enabled,
+ AutoExtract: &autoExtract,
+ MaxEntries: &maxEntries,
+ MaxIndexBytes: &maxIndexBytes,
+ ExtractTimeoutSec: &extractTimeoutSec,
}
}
@@ -413,12 +413,43 @@ func fromPersistedMemoConfig(file persistedMemoConfig, defaults MemoConfig) Memo
if file.ExtractTimeoutSec != nil {
out.ExtractTimeoutSec = *file.ExtractTimeoutSec
}
- if file.ExtractRecentMessages != nil {
- out.ExtractRecentMessages = *file.ExtractRecentMessages
- }
return out
}
+// rejectRemovedMemoFields 在 strict decode 前拦截已删除的 memo 字段,输出明确迁移提示。
+func rejectRemovedMemoFields(data []byte) error {
+ var root yaml.Node
+ if err := yaml.Unmarshal(data, &root); err != nil {
+ return err
+ }
+ if len(root.Content) == 0 {
+ return nil
+ }
+ doc := root.Content[0]
+ if doc.Kind != yaml.MappingNode {
+ return nil
+ }
+
+ for i := 0; i < len(doc.Content); i += 2 {
+ if strings.TrimSpace(doc.Content[i].Value) != "memo" {
+ continue
+ }
+ memoNode := doc.Content[i+1]
+ if memoNode.Kind != yaml.MappingNode {
+ return nil
+ }
+ for j := 0; j < len(memoNode.Content); j += 2 {
+ if strings.TrimSpace(memoNode.Content[j].Value) == "extract_recent_messages" {
+ return fmt.Errorf(
+ "config: memo.extract_recent_messages has been removed; memory extraction now always uses the full run boundary",
+ )
+ }
+ }
+ return nil
+ }
+ return nil
+}
+
// normalizeVerificationSchemaContent 在内存中预处理 verification schema,避免旧字段先于 strict decode 触发硬失败。
func normalizeVerificationSchemaContent(raw []byte) ([]byte, bool, error) {
if len(bytes.TrimSpace(raw)) == 0 {
diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go
index 77cec06f..65d14c6f 100644
--- a/internal/config/loader_test.go
+++ b/internal/config/loader_test.go
@@ -1930,7 +1930,6 @@ memo:
max_entries: 123
max_index_bytes: 4096
extract_timeout_sec: 9
- extract_recent_messages: 4
`
writeLoaderConfig(t, loader, raw)
@@ -1953,9 +1952,6 @@ memo:
if cfg.Memo.ExtractTimeoutSec != 9 {
t.Fatalf("expected memo.extract_timeout_sec=9, got %d", cfg.Memo.ExtractTimeoutSec)
}
- if cfg.Memo.ExtractRecentMessages != 4 {
- t.Fatalf("expected memo.extract_recent_messages=4, got %d", cfg.Memo.ExtractRecentMessages)
- }
data, err := os.ReadFile(loader.ConfigPath())
if err != nil {
@@ -2001,9 +1997,6 @@ shell: powershell
if cfg.Memo.ExtractTimeoutSec <= 0 {
t.Fatalf("expected memo.extract_timeout_sec to be defaulted, got %d", cfg.Memo.ExtractTimeoutSec)
}
- if cfg.Memo.ExtractRecentMessages <= 0 {
- t.Fatalf("expected memo.extract_recent_messages to be defaulted, got %d", cfg.Memo.ExtractRecentMessages)
- }
}
func TestLoaderRejectsLegacyMemoMaxIndexLinesField(t *testing.T) {
@@ -2028,6 +2021,28 @@ memo:
}
}
+func TestLoaderRejectsRemovedMemoExtractRecentMessagesField(t *testing.T) {
+ t.Parallel()
+
+ loader := NewLoader(t.TempDir(), testDefaultConfig())
+ raw := `
+selected_provider: openai
+current_model: gpt-4.1
+shell: powershell
+memo:
+ extract_recent_messages: 4
+`
+ writeLoaderConfig(t, loader, raw)
+
+ cfg, err := loader.Load(context.Background())
+ if err == nil {
+ t.Fatalf("expected removed memo field to be rejected, cfg=%+v", cfg)
+ }
+ if !strings.Contains(err.Error(), "memo.extract_recent_messages has been removed") {
+ t.Fatalf("expected migration hint for extract_recent_messages, got %v", err)
+ }
+}
+
func TestLoaderRejectsExplicitInvalidMemoNumbers(t *testing.T) {
t.Parallel()
@@ -2051,11 +2066,6 @@ func TestLoaderRejectsExplicitInvalidMemoNumbers(t *testing.T) {
fieldYAML: "extract_timeout_sec: -1",
errContain: "config: memo: extract_timeout_sec must be greater than 0",
},
- {
- name: "negative extract_recent_messages",
- fieldYAML: "extract_recent_messages: -1",
- errContain: "config: memo: extract_recent_messages must be greater than 0",
- },
}
for _, tt := range tests {
diff --git a/internal/config/memo.go b/internal/config/memo.go
index 9da2e2d3..010c9acf 100644
--- a/internal/config/memo.go
+++ b/internal/config/memo.go
@@ -3,31 +3,28 @@ package config
import "errors"
const (
- DefaultMemoMaxEntries = 200
- DefaultMemoMaxIndexBytes = 16 * 1024
- DefaultMemoExtractTimeoutSec = 15
- DefaultMemoExtractRecentMessage = 10
+ DefaultMemoMaxEntries = 200
+ DefaultMemoMaxIndexBytes = 16 * 1024
+ DefaultMemoExtractTimeoutSec = 15
)
// MemoConfig 控制跨会话持久记忆的行为配置。
type MemoConfig struct {
- Enabled bool `yaml:"enabled,omitempty"`
- AutoExtract bool `yaml:"auto_extract,omitempty"`
- MaxEntries int `yaml:"max_entries,omitempty"`
- MaxIndexBytes int `yaml:"max_index_bytes,omitempty"`
- ExtractTimeoutSec int `yaml:"extract_timeout_sec,omitempty"`
- ExtractRecentMessages int `yaml:"extract_recent_messages,omitempty"`
+ Enabled bool `yaml:"enabled,omitempty"`
+ AutoExtract bool `yaml:"auto_extract,omitempty"`
+ MaxEntries int `yaml:"max_entries,omitempty"`
+ MaxIndexBytes int `yaml:"max_index_bytes,omitempty"`
+ ExtractTimeoutSec int `yaml:"extract_timeout_sec,omitempty"`
}
// defaultMemoConfig 返回跨会话记忆的默认配置。
func defaultMemoConfig() MemoConfig {
return MemoConfig{
- Enabled: true,
- AutoExtract: true,
- MaxEntries: DefaultMemoMaxEntries,
- MaxIndexBytes: DefaultMemoMaxIndexBytes,
- ExtractTimeoutSec: DefaultMemoExtractTimeoutSec,
- ExtractRecentMessages: DefaultMemoExtractRecentMessage,
+ Enabled: true,
+ AutoExtract: true,
+ MaxEntries: DefaultMemoMaxEntries,
+ MaxIndexBytes: DefaultMemoMaxIndexBytes,
+ ExtractTimeoutSec: DefaultMemoExtractTimeoutSec,
}
}
@@ -50,9 +47,6 @@ func (c *MemoConfig) ApplyDefaults(defaults MemoConfig) {
if c.ExtractTimeoutSec == 0 {
c.ExtractTimeoutSec = defaults.ExtractTimeoutSec
}
- if c.ExtractRecentMessages == 0 {
- c.ExtractRecentMessages = defaults.ExtractRecentMessages
- }
}
// Validate 校验 memo 配置是否合法。
@@ -66,8 +60,5 @@ func (c MemoConfig) Validate() error {
if c.ExtractTimeoutSec <= 0 {
return errors.New("extract_timeout_sec must be greater than 0")
}
- if c.ExtractRecentMessages <= 0 {
- return errors.New("extract_recent_messages must be greater than 0")
- }
return nil
}
diff --git a/internal/context/projection.go b/internal/context/projection.go
index 298f7e8c..265beb0c 100644
--- a/internal/context/projection.go
+++ b/internal/context/projection.go
@@ -88,7 +88,7 @@ func BuildMemoExtractionMessagesForModel(messages []providertypes.Message) []pro
keep := make([]bool, len(messages))
for index := 0; index < len(messages); index++ {
message := messages[index]
- if message.Role == providertypes.RoleTool {
+ if message.Role == providertypes.RoleTool || message.Role == providertypes.RoleSystem {
continue
}
diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go
index 3b0c932a..80edb854 100644
--- a/internal/context/projection_test.go
+++ b/internal/context/projection_test.go
@@ -145,6 +145,7 @@ func TestBuildMemoExtractionMessagesForModelKeepsFullRunSafeSpans(t *testing.T)
messages := []providertypes.Message{
{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}},
+ {Role: providertypes.RoleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart("must call todo_write")}},
{Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan")}},
{
Role: providertypes.RoleAssistant,
@@ -174,6 +175,11 @@ func TestBuildMemoExtractionMessagesForModelKeepsFullRunSafeSpans(t *testing.T)
if renderDisplayParts(projected[0].Parts) != "first" || renderDisplayParts(projected[3].Parts) != "last" {
t.Fatalf("expected full run user messages to remain, got %+v", projected)
}
+ for _, message := range projected {
+ if message.Role == providertypes.RoleSystem {
+ t.Fatalf("system reminder should be excluded from memo extraction window: %+v", projected)
+ }
+ }
if projected[1].Role != providertypes.RoleAssistant || len(projected[1].ToolCalls) != 1 {
t.Fatalf("expected complete assistant tool span, got %+v", projected[1])
}
diff --git a/internal/memo/auto_extractor.go b/internal/memo/auto_extractor.go
index a1ccb43b..812d64e8 100644
--- a/internal/memo/auto_extractor.go
+++ b/internal/memo/auto_extractor.go
@@ -149,7 +149,6 @@ func (a *AutoExtractor) handleDebounce(sessionID string, state *autoExtractState
req := *state.pending
state.pending = nil
- // 增量检测:消息未变化时跳过提取
fp := computeMessageFingerprint(req.messages)
if fp == state.lastFingerprint {
a.armIdleTimerLocked(sessionID, state)
@@ -183,7 +182,7 @@ func (a *AutoExtractor) handleRunDone(sessionID string, state *autoExtractState)
a.armIdleTimerLocked(sessionID, state)
}
-// armIdleTimerLocked 在会话空闲时安排状态回收,避免 map 与 goroutine 长期累积。
+// armIdleTimerLocked 在会话空闲时安排状态回收,避免 map 中 goroutine 长期累积。
func (a *AutoExtractor) armIdleTimerLocked(sessionID string, state *autoExtractState) {
stopTimer(state.idleTimer)
state.idleSeq++
@@ -203,7 +202,6 @@ func (a *AutoExtractor) handleIdle(sessionID string, state *autoExtractState, se
a.mu.Lock()
defer a.mu.Unlock()
- // 在删除前再次确认 map 中仍指向当前状态,防止旧回调回收新状态。
if a.states[sessionID] != state {
return
}
@@ -219,30 +217,11 @@ func isIdleStateLocked(state *autoExtractState, seq uint64) bool {
return state.idleSeq == seq && !state.running && state.pending == nil
}
-// extractAndStore 执行提取,并在写入前做本地批次去重和持久化级别的原子去重。
-// 返回值表示本次提取和写入流程是否成功完成,可用于更新增量指纹。
+// extractAndStore 执行提取,并在写入前完成本地去重和语义决策。
func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []providertypes.Message) bool {
ctx, cancel := context.WithTimeout(context.Background(), a.extractTimeout)
defer cancel()
- if decisionExtractor, ok := extractor.(DecisionExtractor); ok {
- existing, err := a.svc.autoExtractionCandidates(ctx)
- if err != nil {
- a.logError("memo: auto extract load candidates failed: %v", err)
- return false
- }
- decisions, err := decisionExtractor.ExtractDecisions(ctx, messages, existing)
- if err != nil {
- if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) {
- a.logError("memo: auto extract skipped (protocol_mismatch): %v", err)
- return true
- }
- a.logError("memo: auto extract failed: %v", err)
- return false
- }
- return a.applyExtractionDecisions(ctx, decisions)
- }
-
entries, err := extractor.Extract(ctx, messages)
if err != nil {
if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) {
@@ -256,6 +235,7 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider
return true
}
+ resolver, _ := extractor.(DecisionResolver)
seen := make(map[string]struct{}, len(entries))
succeeded := true
for _, entry := range entries {
@@ -267,60 +247,55 @@ func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []provider
if _, exists := seen[key]; exists {
continue
}
+ seen[key] = struct{}{}
+
+ if resolver == nil {
+ if _, err := a.svc.addAutoExtractIfAbsent(ctx, entry); err != nil {
+ a.logError("memo: auto extract add failed: %v", err)
+ succeeded = false
+ }
+ continue
+ }
- added, err := a.svc.addAutoExtractIfAbsent(ctx, entry)
+ shortlist, err := a.svc.semanticCandidateShortlist(ctx, entry, semanticCandidateShortlistLimit)
if err != nil {
- a.logError("memo: auto extract add failed: %v", err)
+ a.logError("memo: auto extract shortlist failed: %v", err)
succeeded = false
continue
}
-
- seen[key] = struct{}{}
- if !added {
+ if len(shortlist) == 0 {
+ if _, err := a.svc.addAutoExtractIfAbsent(ctx, entry); err != nil {
+ a.logError("memo: auto extract add failed: %v", err)
+ succeeded = false
+ }
continue
}
- }
- return succeeded
-}
-// applyExtractionDecisions 应用模型返回的 create/update/skip 决策,并保留本地精确去重兜底。
-func (a *AutoExtractor) applyExtractionDecisions(ctx context.Context, decisions []ExtractionDecision) bool {
- if len(decisions) == 0 {
- return true
- }
+ decision, err := resolver.ResolveDecision(ctx, entry, shortlist)
+ if err != nil {
+ if errors.Is(err, ErrExtractionNoJSONArray) || errors.Is(err, ErrExtractionIncompleteJSONArray) {
+ a.logError("memo: auto extract skipped (protocol_mismatch): %v", err)
+ return true
+ }
+ a.logError("memo: auto extract resolve failed: %v", err)
+ succeeded = false
+ continue
+ }
- seenCreates := make(map[string]struct{}, len(decisions))
- succeeded := true
- for _, decision := range decisions {
switch decision.Action {
case ExtractionActionCreate:
- entry := decision.Entry
- entry.Source = SourceAutoExtract
- key := autoExtractDedupKey(entry)
- if key == "" {
- continue
- }
- if _, exists := seenCreates[key]; exists {
- continue
- }
- added, err := a.svc.addAutoExtractIfAbsent(ctx, entry)
- if err != nil {
+ createEntry := decision.Entry
+ createEntry.Source = SourceAutoExtract
+ if _, err := a.svc.addAutoExtractIfAbsent(ctx, createEntry); err != nil {
a.logError("memo: auto extract add failed: %v", err)
succeeded = false
- continue
- }
- seenCreates[key] = struct{}{}
- if !added {
- continue
}
case ExtractionActionUpdate:
- entry := decision.Entry
- entry.Source = SourceAutoExtract
- _, err := a.svc.updateAutoExtractIfAllowed(ctx, decision.Ref, entry)
- if err != nil {
+ updateEntry := decision.Entry
+ updateEntry.Source = SourceAutoExtract
+ if _, err := a.svc.updateAutoExtractIfAllowed(ctx, decision.Ref, updateEntry); err != nil {
a.logError("memo: auto extract update failed: %v", err)
succeeded = false
- continue
}
case ExtractionActionSkip:
continue
@@ -341,21 +316,8 @@ func autoExtractDedupKey(entry Entry) string {
// parseTopicSourceAndContent 从 topic 文件中解析 source frontmatter 和正文内容。
func parseTopicSourceAndContent(topic string) (string, string) {
- parts := strings.Split(topic, "---")
- if len(parts) < 3 {
- return "", strings.TrimSpace(topic)
- }
-
- frontmatter := parts[1]
- body := strings.TrimSpace(strings.Join(parts[2:], "---"))
- for _, line := range strings.Split(frontmatter, "\n") {
- line = strings.TrimSpace(line)
- if !strings.HasPrefix(line, "source:") {
- continue
- }
- return strings.TrimSpace(strings.TrimPrefix(line, "source:")), body
- }
- return "", body
+ source, _, body := parseTopicSnapshot(topic)
+ return source, body
}
// cloneProviderMessages 深拷贝消息切片,避免后台任务读取到后续修改。
@@ -400,7 +362,6 @@ func computeMessageFingerprint(messages []providertypes.Message) uint64 {
return h.Sum64()
}
- // 回退路径仅在意外序列化失败时使用,至少保留文本消息的增量检测能力。
for _, msg := range messages {
_, _ = h.Write([]byte(msg.Role))
_, _ = h.Write([]byte{0})
diff --git a/internal/memo/auto_extractor_test.go b/internal/memo/auto_extractor_test.go
index 6c4e5835..06e012c9 100644
--- a/internal/memo/auto_extractor_test.go
+++ b/internal/memo/auto_extractor_test.go
@@ -39,30 +39,38 @@ func (s *stubMemoExtractor) Calls() int {
}
type stubDecisionMemoExtractor struct {
- mu sync.Mutex
- callCount int
- candidates []ExtractionCandidate
- decisions []ExtractionDecision
- err error
+ mu sync.Mutex
+ callCount int
+ candidates []ExtractionCandidate
+ extractEntries []Entry
+ decisions []ExtractionDecision
+ err error
}
func (s *stubDecisionMemoExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) {
- return nil, nil
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return append([]Entry(nil), s.extractEntries...), nil
}
-func (s *stubDecisionMemoExtractor) ExtractDecisions(
+func (s *stubDecisionMemoExtractor) ResolveDecision(
ctx context.Context,
- messages []providertypes.Message,
+ candidate Entry,
existing []ExtractionCandidate,
-) ([]ExtractionDecision, error) {
+) (ExtractionDecision, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.callCount++
s.candidates = append([]ExtractionCandidate(nil), existing...)
if s.err != nil {
- return nil, s.err
+ return ExtractionDecision{}, s.err
}
- return append([]ExtractionDecision(nil), s.decisions...), nil
+ if len(s.decisions) == 0 {
+ return ExtractionDecision{Action: ExtractionActionCreate, Entry: candidate}, nil
+ }
+ decision := s.decisions[0]
+ s.decisions = append([]ExtractionDecision(nil), s.decisions[1:]...)
+ return decision, nil
}
func newAutoExtractorTestService(t *testing.T) *Service {
@@ -278,6 +286,9 @@ func TestAutoExtractorAppliesSemanticUpdateOnlyForAutoExtractedMemory(t *testing
}
extractor := &stubDecisionMemoExtractor{
+ extractEntries: []Entry{
+ {Type: TypeFeedback, Title: "测试策略", Content: "用户要求修改后先跑相关测试。"},
+ },
decisions: []ExtractionDecision{
{
Action: ExtractionActionUpdate,
@@ -324,8 +335,8 @@ func TestAutoExtractorAppliesSemanticUpdateOnlyForAutoExtractedMemory(t *testing
if len(manualRecall) != 1 || strings.Contains(manualRecall[0].Content, "不应覆盖") {
t.Fatalf("manual memory should not be overwritten, got %+v", manualRecall)
}
- if len(extractor.candidates) != 2 {
- t.Fatalf("expected existing candidates to be provided, got %+v", extractor.candidates)
+ if len(extractor.candidates) != 1 || extractor.candidates[0].Ref != autoRef {
+ t.Fatalf("expected shortlist to target the auto-extracted memory, got %+v", extractor.candidates)
}
}
@@ -341,6 +352,10 @@ func TestAutoExtractorSemanticCreateStillUsesExactDedup(t *testing.T) {
}
extractor := &stubDecisionMemoExtractor{
+ extractEntries: []Entry{
+ {Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"},
+ {Type: TypeProject, Title: "新事实", Content: "项目需要语义去重。"},
+ },
decisions: []ExtractionDecision{
{Action: ExtractionActionCreate, Entry: Entry{Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"}},
{Action: ExtractionActionCreate, Entry: Entry{Type: TypeProject, Title: "新事实", Content: "项目需要语义去重。"}},
diff --git a/internal/memo/llm_extractor.go b/internal/memo/llm_extractor.go
index 621446de..ba239fe5 100644
--- a/internal/memo/llm_extractor.go
+++ b/internal/memo/llm_extractor.go
@@ -13,17 +13,16 @@ import (
)
var (
- // ErrExtractionNoJSONArray 表示提取结果中找不到合法 JSON 数组起始。
- ErrExtractionNoJSONArray = errors.New("memo: extraction response does not contain a JSON array")
- // ErrExtractionIncompleteJSONArray 表示提取结果中的 JSON 数组不完整。
- ErrExtractionIncompleteJSONArray = errors.New("memo: extraction response contains an incomplete JSON array")
+ // ErrExtractionNoJSONArray 表示提取结果中找不到合法 JSON 数组或对象起始。
+ ErrExtractionNoJSONArray = errors.New("memo: extraction response does not contain a JSON payload")
+ // ErrExtractionIncompleteJSONArray 表示提取结果中的 JSON 数组或对象不完整。
+ ErrExtractionIncompleteJSONArray = errors.New("memo: extraction response contains an incomplete JSON payload")
)
-// LLMExtractor 基于 LLM 分析最近对话,并返回结构化记忆条目。
+// LLMExtractor 基于 LLM 分析当前 run 对话,并输出结构化记忆或去重决策。
type LLMExtractor struct {
- generator TextGenerator
- now func() time.Time
- recentMessageLimit int
+ generator TextGenerator
+ now func() time.Time
}
type extractedEntry struct {
@@ -36,38 +35,15 @@ type extractedEntry struct {
}
// NewLLMExtractor 创建基于 TextGenerator 的记忆提取器。
-func NewLLMExtractor(generator TextGenerator, recentMessageLimit int) *LLMExtractor {
- if recentMessageLimit <= 0 {
- recentMessageLimit = 10
- }
+func NewLLMExtractor(generator TextGenerator) *LLMExtractor {
return &LLMExtractor{
- generator: generator,
- now: time.Now,
- recentMessageLimit: recentMessageLimit,
+ generator: generator,
+ now: time.Now,
}
}
// Extract 从当前 run 对话中提取可跨会话持久化的新增记忆条目。
func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error) {
- decisions, err := e.ExtractDecisions(ctx, messages, nil)
- if err != nil {
- return nil, err
- }
- entries := make([]Entry, 0, len(decisions))
- for _, decision := range decisions {
- if decision.Action == ExtractionActionCreate {
- entries = append(entries, decision.Entry)
- }
- }
- return entries, nil
-}
-
-// ExtractDecisions 从当前 run 对话中提取记忆,并结合既有记忆输出新增、合并或跳过决策。
-func (e *LLMExtractor) ExtractDecisions(
- ctx context.Context,
- messages []providertypes.Message,
- existing []ExtractionCandidate,
-) ([]ExtractionDecision, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
@@ -80,7 +56,7 @@ func (e *LLMExtractor) ExtractDecisions(
return nil, nil
}
- response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now(), existing), runMessages)
+ response, err := e.generator.Generate(ctx, buildExtractionPrompt(e.now()), runMessages)
if err != nil {
return nil, err
}
@@ -98,33 +74,63 @@ func (e *LLMExtractor) ExtractDecisions(
return nil, fmt.Errorf("memo: parse extraction response: %w", err)
}
- decisions := make([]ExtractionDecision, 0, len(extracted))
+ entries := make([]Entry, 0, len(extracted))
for _, item := range extracted {
- decision, ok := toExtractionDecision(item)
+ entry, ok := toMemoEntry(item)
if !ok {
continue
}
- decisions = append(decisions, decision)
+ entries = append(entries, entry)
}
- return decisions, nil
+ return entries, nil
}
-// buildExtractionPrompt 构造记忆提取专用的 system prompt。
-func buildExtractionPrompt(now time.Time, existing []ExtractionCandidate) string {
- currentDate := now.In(time.Local).Format("2006-01-02")
- existingJSON := "[]"
- if len(existing) > 0 {
- if data, err := json.Marshal(existing); err == nil {
- existingJSON = string(data)
- }
+// ResolveDecision 结合单条候选记忆与 shortlist,解析 create/update/skip 决策。
+func (e *LLMExtractor) ResolveDecision(
+ ctx context.Context,
+ candidate Entry,
+ existing []ExtractionCandidate,
+) (ExtractionDecision, error) {
+ if err := ctx.Err(); err != nil {
+ return ExtractionDecision{}, err
+ }
+ if e == nil || e.generator == nil {
+ return ExtractionDecision{}, errors.New("memo: text generator is nil")
+ }
+
+ response, err := e.generator.Generate(ctx, buildResolutionPrompt(e.now(), candidate, existing), buildResolutionMessages(candidate))
+ if err != nil {
+ return ExtractionDecision{}, err
+ }
+ if err := ctx.Err(); err != nil {
+ return ExtractionDecision{}, err
}
+
+ jsonText, err := extractJSONObject(response)
+ if err != nil {
+ return ExtractionDecision{}, err
+ }
+
+ var extracted extractedEntry
+ if err := json.Unmarshal([]byte(jsonText), &extracted); err != nil {
+ return ExtractionDecision{}, fmt.Errorf("memo: parse resolution response: %w", err)
+ }
+ decision, ok := toExtractionDecision(extracted)
+ if !ok {
+ return ExtractionDecision{}, errors.New("memo: resolution response is invalid")
+ }
+ return decision, nil
+}
+
+// buildExtractionPrompt 构造仅负责“当前 run 提取”的 system prompt。
+func buildExtractionPrompt(now time.Time) string {
+ currentDate := now.In(time.Local).Format("2006-01-02")
return strings.TrimSpace(fmt.Sprintf(`
你是一个记忆提取助手(memory extraction assistant)。
-分析当前 run 对话中值得跨会话持久记住的信息,并结合既有记忆完成语义去重,返回严格 JSON 数组。
+分析当前 run 对话中值得跨会话持久记住的信息,并返回严格 JSON 数组。
当前本地日期:%s
如果对话中出现“今天、明天、下周二”等相对日期,必须先转换为绝对日期再写入 content。
-
只允许以下四种 type:
- user: 用户偏好、习惯、背景、专长
- feedback: 用户对 Agent 做法的纠正、要求、确认过的工作方式
@@ -136,18 +142,68 @@ func buildExtractionPrompt(now time.Time, existing []ExtractionCandidate) string
2. 不要提取通用编程知识、代码结构、文件路径、Git 历史。
3. 每条记忆必须具体、可操作。
4. 没有值得记住的信息时,返回 []。
-5. 如果新信息与既有记忆语义相同或只是轻微改写,输出 action="skip"。
-6. 如果新信息能补充或修正既有 source="extractor_auto" 的记忆,输出 action="update" 并填写目标 ref。
-7. 不允许 update source 不是 "extractor_auto" 的既有记忆;这类相近内容只能 skip。
-8. 如果是全新的可持久化信息,输出 action="create"。
-9. 输出必须是 JSON 数组,不要输出任何额外解释。
+5. 输出必须是 JSON 数组,不要输出任何额外解释。
-既有记忆候选(JSON):
+输出格式:
+[{"type":"user","title":"...","content":"...","keywords":["..."]}]
+`, currentDate))
+}
+
+// buildResolutionPrompt 构造单条候选记忆的去重决策提示。
+func buildResolutionPrompt(now time.Time, candidate Entry, existing []ExtractionCandidate) string {
+ currentDate := now.In(time.Local).Format("2006-01-02")
+ candidateJSON := marshalPromptJSON(struct {
+ Type Type `json:"type"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Keywords []string `json:"keywords,omitempty"`
+ }{
+ Type: candidate.Type,
+ Title: candidate.Title,
+ Content: candidate.Content,
+ Keywords: candidate.Keywords,
+ })
+ existingJSON := marshalPromptJSON(existing)
+ return strings.TrimSpace(fmt.Sprintf(`
+你是一个记忆去重决策助手(memory dedupe assistant)。
+分析一条新的候选记忆与已有记忆 shortlist,返回严格 JSON 对象。
+
+当前本地日期:%s
+如果候选内容中的日期是相对日期,先换算成绝对日期再判断。
+
+规则:
+1. 若候选记忆与 shortlist 中某条记忆语义相同,返回 {"action":"skip","ref":"..."}。
+2. 若候选记忆补充或修正了某条 source="extractor_auto" 的已有记忆,返回 {"action":"update","ref":"...","title":"...","content":"...","keywords":[...]}。
+3. 不允许更新 source 不是 "extractor_auto" 的已有记忆;遇到这类情况只能返回 skip。
+4. 若 shortlist 中没有合适对象,返回 {"action":"create","type":"...","title":"...","content":"...","keywords":[...]}。
+5. 输出必须是单个 JSON 对象,不要输出任何额外解释。
+
+候选记忆(JSON):
%s
-输出格式:
-[{"action":"create","type":"user","title":"...","content":"...","keywords":["..."]},{"action":"update","ref":"project:p.md","title":"...","content":"...","keywords":["..."]},{"action":"skip","ref":"user:u.md"}]
-`, currentDate, existingJSON))
+已有记忆 shortlist(JSON):
+%s
+`, currentDate, candidateJSON, existingJSON))
+}
+
+// buildResolutionMessages 为去重决策构造最小 provider 输入,避免空消息列表触发兼容问题。
+func buildResolutionMessages(candidate Entry) []providertypes.Message {
+ text := fmt.Sprintf("请为这条候选记忆返回 create、update 或 skip 决策:%s", candidate.Title)
+ return []providertypes.Message{
+ {
+ Role: providertypes.RoleUser,
+ Parts: []providertypes.ContentPart{providertypes.NewTextPart(text)},
+ },
+ }
+}
+
+// marshalPromptJSON 将结构化数据压缩为 prompt 内联 JSON。
+func marshalPromptJSON(value any) string {
+ data, err := json.Marshal(value)
+ if err != nil {
+ return "null"
+ }
+ return string(data)
}
// containsUserMessage 检查待提取消息中是否包含用户输入。
@@ -189,9 +245,13 @@ func toExtractionDecision(item extractedEntry) (ExtractionDecision, bool) {
return ExtractionDecision{}, false
}
if action == ExtractionActionSkip {
+ ref := strings.TrimSpace(item.Ref)
+ if ref == "" {
+ return ExtractionDecision{}, false
+ }
return ExtractionDecision{
Action: action,
- Ref: strings.TrimSpace(item.Ref),
+ Ref: ref,
}, true
}
@@ -312,3 +372,45 @@ func extractJSONArray(text string) (string, error) {
return "", ErrExtractionIncompleteJSONArray
}
+
+// extractJSONObject 从模型返回文本中提取最外层 JSON 对象。
+func extractJSONObject(text string) (string, error) {
+ start := strings.Index(text, "{")
+ if start < 0 {
+ return "", ErrExtractionNoJSONArray
+ }
+
+ depth := 0
+ inString := false
+ escaped := false
+ for index := start; index < len(text); index++ {
+ ch := text[index]
+ if inString {
+ if escaped {
+ escaped = false
+ continue
+ }
+ switch ch {
+ case '\\':
+ escaped = true
+ case '"':
+ inString = false
+ }
+ continue
+ }
+
+ switch ch {
+ case '"':
+ inString = true
+ case '{':
+ depth++
+ case '}':
+ depth--
+ if depth == 0 {
+ return strings.TrimSpace(text[start : index+1]), nil
+ }
+ }
+ }
+
+ return "", ErrExtractionIncompleteJSONArray
+}
diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go
index 06813229..987c59fb 100644
--- a/internal/memo/llm_extractor_test.go
+++ b/internal/memo/llm_extractor_test.go
@@ -32,12 +32,12 @@ func (s *stubTextGenerator) Generate(
return s.response, nil
}
-// TestLLMExtractorExtractValidJSON 验证提取器可以解析合法 JSON 并收敛字段。
+// TestLLMExtractorExtractValidJSON 验证提取器可解析合法 JSON 并收敛字段。
func TestLLMExtractorExtractValidJSON(t *testing.T) {
generator := &stubTextGenerator{
- response: `[{"type":"user","title":" 偏好 Go 代码风格 ","content":"用户偏好使用 Go 惯用写法。","keywords":["go"," style ","go"]}]`,
+ response: `[{"type":"user","title":" 偏好 Go 代码风格 ","content":"用户偏好使用 Go 惯用写法。","keywords":["go"," style ","go"]}]`,
}
- extractor := NewLLMExtractor(generator, 10)
+ extractor := NewLLMExtractor(generator)
extractor.now = func() time.Time {
return time.Date(2026, 4, 13, 10, 0, 0, 0, time.FixedZone("CST", 8*3600))
}
@@ -70,219 +70,100 @@ func TestLLMExtractorExtractValidJSON(t *testing.T) {
if generator.calls != 1 {
t.Fatalf("Generate() calls = %d, want 1", generator.calls)
}
- if !strings.Contains(generator.prompt, "user: 用户偏好") {
+ if !strings.Contains(generator.prompt, "用户偏好") {
t.Fatalf("prompt should describe user type, got %q", generator.prompt)
}
- if !strings.Contains(generator.prompt, "当前本地日期:2026-04-13") {
+ if !strings.Contains(generator.prompt, "2026-04-13") {
t.Fatalf("prompt should include absolute local date, got %q", generator.prompt)
}
- if !strings.Contains(generator.prompt, "必须先转换为绝对日期") {
- t.Fatalf("prompt should require absolute dates, got %q", generator.prompt)
- }
-}
-
-// TestLLMExtractorExtractEmptyResult 验证空数组响应会返回零条记忆。
-func TestLLMExtractorExtractEmptyResult(t *testing.T) {
- extractor := NewLLMExtractor(&stubTextGenerator{response: `[]`}, 10)
-
- entries, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("这轮没有需要记住的内容。")}},
- })
- if err != nil {
- t.Fatalf("Extract() error = %v", err)
- }
- if len(entries) != 0 {
- t.Fatalf("len(entries) = %d, want 0", len(entries))
- }
-}
-
-// TestLLMExtractorExtractNoUserMessage 验证没有用户消息时不会调用模型。
-func TestLLMExtractorExtractNoUserMessage(t *testing.T) {
- generator := &stubTextGenerator{response: `[]`}
- extractor := NewLLMExtractor(generator, 10)
-
- entries, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("只有助手消息。")}},
- })
- if err != nil {
- t.Fatalf("Extract() error = %v", err)
- }
- if len(entries) != 0 {
- t.Fatalf("len(entries) = %d, want 0", len(entries))
- }
- if generator.calls != 0 {
- t.Fatalf("Generate() calls = %d, want 0", generator.calls)
- }
-}
-
-// TestLLMExtractorExtractNoMessages 验证空消息输入直接返回空结果。
-func TestLLMExtractorExtractNoMessages(t *testing.T) {
- extractor := NewLLMExtractor(&stubTextGenerator{response: `[]`}, 10)
-
- entries, err := extractor.Extract(context.Background(), nil)
- if err != nil {
- t.Fatalf("Extract() error = %v", err)
- }
- if len(entries) != 0 {
- t.Fatalf("len(entries) = %d, want 0", len(entries))
- }
-}
-
-// TestLLMExtractorExtractInvalidJSON 验证无效 JSON 会返回错误。
-func TestLLMExtractorExtractInvalidJSON(t *testing.T) {
- extractor := NewLLMExtractor(&stubTextGenerator{response: `[{invalid json}]`}, 10)
-
- _, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}},
- })
- if err == nil {
- t.Fatal("expected invalid JSON error")
- }
-}
-
-// TestLLMExtractorExtractToleratesWrappedJSON 验证 JSON 前后噪声不会影响解析。
-func TestLLMExtractorExtractToleratesWrappedJSON(t *testing.T) {
- extractor := NewLLMExtractor(&stubTextGenerator{
- response: "分析如下:\n[{\"type\":\"feedback\",\"title\":\"以后先跑测试\",\"content\":\"用户要求修改后先跑测试。\"}]\n以上完毕。",
- }, 10)
-
- entries, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑测试。")}},
- })
- if err != nil {
- t.Fatalf("Extract() error = %v", err)
- }
- if len(entries) != 1 || entries[0].Type != TypeFeedback {
- t.Fatalf("entries = %#v", entries)
- }
}
-// TestLLMExtractorExtractFiltersInvalidEntries 验证非法类型和空字段会被过滤。
-func TestLLMExtractorExtractFiltersInvalidEntries(t *testing.T) {
- extractor := NewLLMExtractor(&stubTextGenerator{
- response: `[
- {"type":"invalid","title":"bad","content":"bad"},
- {"type":"project","title":" ","content":"missing title"},
- {"type":"reference","title":"文档入口","content":"查看 docs/runtime-provider-event-flow.md"}
- ]`,
- }, 10)
-
- entries, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("参考文档在 docs/runtime-provider-event-flow.md。")}},
- })
- if err != nil {
- t.Fatalf("Extract() error = %v", err)
- }
- if len(entries) != 1 || entries[0].Type != TypeReference {
- t.Fatalf("entries = %#v", entries)
+// TestLLMExtractorExtractSkipsEmptyOrNonUserInputs 验证缺少有效用户输入时不会调用模型。
+func TestLLMExtractorExtractSkipsEmptyOrNonUserInputs(t *testing.T) {
+ tests := []struct {
+ name string
+ messages []providertypes.Message
+ }{
+ {
+ name: "nil messages",
+ messages: nil,
+ },
+ {
+ name: "assistant only",
+ messages: []providertypes.Message{
+ {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("只有助手消息。")}},
+ },
+ },
+ {
+ name: "image only user",
+ messages: []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewRemoteImagePart("https://example.com/pic.png")}},
+ },
+ },
}
-}
-
-// TestLLMExtractorExtractCancelledContext 验证已取消上下文会中止提取。
-func TestLLMExtractorExtractCancelledContext(t *testing.T) {
- generator := &stubTextGenerator{response: `[]`}
- extractor := NewLLMExtractor(generator, 10)
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- _, err := extractor.Extract(ctx, []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}},
- })
- if !errors.Is(err, context.Canceled) {
- t.Fatalf("Extract() error = %v, want context.Canceled", err)
- }
- if generator.calls != 0 {
- t.Fatalf("Generate() calls = %d, want 0", generator.calls)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ generator := &stubTextGenerator{response: `[]`}
+ extractor := NewLLMExtractor(generator)
+ entries, err := extractor.Extract(context.Background(), tt.messages)
+ if err != nil {
+ t.Fatalf("Extract() error = %v", err)
+ }
+ if len(entries) != 0 {
+ t.Fatalf("len(entries) = %d, want 0", len(entries))
+ }
+ if generator.calls != 0 {
+ t.Fatalf("Generate() calls = %d, want 0", generator.calls)
+ }
+ })
}
}
-// TestLLMExtractorExtractUsesFullRunMessages 验证提取器使用完整 run 消息而非固定 recent window。
+// TestLLMExtractorExtractUsesFullRunMessages 验证提取器使用完整 run 消息并排除 system/tool 噪声。
func TestLLMExtractorExtractUsesFullRunMessages(t *testing.T) {
generator := &stubTextGenerator{response: `[]`}
- extractor := NewLLMExtractor(generator, 10)
+ extractor := NewLLMExtractor(generator)
- messages := make([]providertypes.Message, 0, 16)
- for index := 0; index < 12; index++ {
- messages = append(messages, providertypes.Message{
- Role: providertypes.RoleUser,
- Parts: []providertypes.ContentPart{providertypes.NewTextPart("user-" + string(rune('a'+index)))},
- })
- if index%3 == 0 {
- messages = append(messages, providertypes.Message{
- Role: providertypes.RoleTool,
- Parts: []providertypes.ContentPart{providertypes.NewTextPart("tool-" + string(rune('a'+index)))},
- })
- }
+ messages := []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}},
+ {Role: providertypes.RoleSystem, Parts: []providertypes.ContentPart{providertypes.NewTextPart("must call todo_write")}},
+ {
+ Role: providertypes.RoleAssistant,
+ ToolCalls: []providertypes.ToolCall{
+ {ID: "call_1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`},
+ },
+ },
+ {
+ Role: providertypes.RoleTool,
+ ToolCallID: "call_1",
+ Parts: []providertypes.ContentPart{providertypes.NewTextPart("README body")},
+ ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"},
+ },
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("last")}},
}
_, err := extractor.Extract(context.Background(), messages)
if err != nil {
t.Fatalf("Extract() error = %v", err)
}
- if len(generator.messages) != 12 {
- t.Fatalf("len(generator.messages) = %d, want 12", len(generator.messages))
+ if len(generator.messages) != 4 {
+ t.Fatalf("len(generator.messages) = %d, want 4", len(generator.messages))
}
for _, message := range generator.messages {
- if message.Role == providertypes.RoleTool {
- t.Fatalf("unexpected tool message in extraction context: %#v", message)
+ if message.Role == providertypes.RoleSystem {
+ t.Fatalf("system reminder should not enter extraction window: %#v", message)
}
}
- if renderMemoParts(generator.messages[0].Parts) != "user-a" ||
- renderMemoParts(generator.messages[11].Parts) != "user-l" {
- t.Fatalf("unexpected run window: first=%q last=%q",
- renderMemoParts(generator.messages[0].Parts),
- renderMemoParts(generator.messages[11].Parts))
- }
-}
-
-func TestLLMExtractorExtractDecisionsIncludesExistingCandidates(t *testing.T) {
- generator := &stubTextGenerator{
- response: `[{"action":"skip","ref":"user:u.md"},{"action":"update","ref":"project:p.md","title":"测试策略","content":"用户要求修改后先跑相关测试。","keywords":["test"]}]`,
- }
- extractor := NewLLMExtractor(generator, 10)
-
- decisions, err := extractor.ExtractDecisions(
- context.Background(),
- []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑相关测试。")}},
- },
- []ExtractionCandidate{
- {
- Ref: "project:p.md",
- Scope: ScopeProject,
- Type: TypeFeedback,
- Source: SourceAutoExtract,
- Title: "测试策略",
- Content: "用户要求修改后先跑测试。",
- },
- },
- )
- if err != nil {
- t.Fatalf("ExtractDecisions() error = %v", err)
- }
- if len(decisions) != 2 {
- t.Fatalf("len(decisions) = %d, want 2", len(decisions))
- }
- if decisions[0].Action != ExtractionActionSkip || decisions[0].Ref != "user:u.md" {
- t.Fatalf("unexpected skip decision: %+v", decisions[0])
- }
- if decisions[1].Action != ExtractionActionUpdate || decisions[1].Ref != "project:p.md" {
- t.Fatalf("unexpected update decision: %+v", decisions[1])
- }
- if decisions[1].Entry.Type != "" {
- t.Fatalf("update decision should not require type, got %+v", decisions[1].Entry)
- }
- if !strings.Contains(generator.prompt, `"ref":"project:p.md"`) {
- t.Fatalf("prompt should include existing memory candidates, got %q", generator.prompt)
- }
- if !strings.Contains(generator.prompt, `action="update"`) || !strings.Contains(generator.prompt, `source="extractor_auto"`) {
- t.Fatalf("prompt should describe semantic dedupe protocol, got %q", generator.prompt)
+ if renderMemoParts(generator.messages[0].Parts) != "first" || renderMemoParts(generator.messages[3].Parts) != "last" {
+ t.Fatalf("unexpected extraction window: %+v", generator.messages)
}
}
+// TestLLMExtractorExtractDropsIncompleteToolCallSpan 验证不完整的 tool call span 会被剔除。
func TestLLMExtractorExtractDropsIncompleteToolCallSpan(t *testing.T) {
generator := &stubTextGenerator{response: `[]`}
- extractor := NewLLMExtractor(generator, 10)
+ extractor := NewLLMExtractor(generator)
_, err := extractor.Extract(context.Background(), []providertypes.Message{
{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}},
@@ -307,169 +188,134 @@ func TestLLMExtractorExtractDropsIncompleteToolCallSpan(t *testing.T) {
}
}
-func TestLLMExtractorExtractKeepsProjectedToolCallSpan(t *testing.T) {
- generator := &stubTextGenerator{response: `[]`}
- extractor := NewLLMExtractor(generator, 10)
+// TestLLMExtractorResolveDecisionUsesShortlist 验证决策阶段会携带 shortlist 并解析 update 结果。
+func TestLLMExtractorResolveDecisionUsesShortlist(t *testing.T) {
+ generator := &stubTextGenerator{
+ response: `{"action":"update","ref":"project:p.md","title":"测试策略","content":"用户要求修改后先跑相关测试。","keywords":["test"]}`,
+ }
+ extractor := NewLLMExtractor(generator)
- _, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("remember this")}},
- {
- Role: providertypes.RoleAssistant,
- ToolCalls: []providertypes.ToolCall{
- {ID: "call_1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`},
+ decision, err := extractor.ResolveDecision(
+ context.Background(),
+ Entry{Type: TypeFeedback, Title: "测试策略", Content: "以后修改完先跑相关测试。"},
+ []ExtractionCandidate{
+ {
+ Ref: "project:p.md",
+ Scope: ScopeProject,
+ Type: TypeFeedback,
+ Source: SourceAutoExtract,
+ Title: "测试策略",
+ Keywords: []string{"test"},
+ Content: "用户要求修改后先跑测试。",
},
},
- {
- Role: providertypes.RoleTool,
- ToolCallID: "call_1",
- Parts: []providertypes.ContentPart{providertypes.NewTextPart("README body")},
- ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"},
- },
- })
+ )
if err != nil {
- t.Fatalf("Extract() error = %v", err)
+ t.Fatalf("ResolveDecision() error = %v", err)
}
- if len(generator.messages) != 3 {
- t.Fatalf("len(generator.messages) = %d, want 3", len(generator.messages))
+ if decision.Action != ExtractionActionUpdate || decision.Ref != "project:p.md" {
+ t.Fatalf("unexpected decision: %+v", decision)
}
- if generator.messages[1].Role != providertypes.RoleAssistant || len(generator.messages[1].ToolCalls) != 1 {
- t.Fatalf("expected assistant tool call span to be preserved, got %#v", generator.messages[1])
+ if decision.Entry.Type != "" {
+ t.Fatalf("update decision should not require type, got %+v", decision.Entry)
}
- toolMessage := generator.messages[2]
- if toolMessage.Role != providertypes.RoleTool {
- t.Fatalf("expected projected tool message, got %#v", toolMessage)
+ if len(generator.messages) != 1 || generator.messages[0].Role != providertypes.RoleUser {
+ t.Fatalf("expected one synthetic user message, got %+v", generator.messages)
}
- toolText := renderMemoParts(toolMessage.Parts)
- if !strings.Contains(toolText, "tool result") || !strings.Contains(toolText, "tool: filesystem_read_file") {
- t.Fatalf("expected projected tool text, got %q", toolText)
+ if !strings.Contains(generator.prompt, `"ref":"project:p.md"`) {
+ t.Fatalf("prompt should include shortlist candidates, got %q", generator.prompt)
}
- if toolMessage.ToolMetadata != nil {
- t.Fatalf("expected projected tool metadata to be cleared, got %#v", toolMessage.ToolMetadata)
+ if !strings.Contains(generator.prompt, `source="extractor_auto"`) {
+ t.Fatalf("prompt should describe auto-extracted update rule, got %q", generator.prompt)
}
}
-func TestLLMExtractorExtractKeepsMetadataOnlyToolCallSpan(t *testing.T) {
- generator := &stubTextGenerator{response: `[]`}
- extractor := NewLLMExtractor(generator, 10)
+// TestLLMExtractorResolveDecisionSupportsCreateAndSkip 验证单条决策支持 create 与 skip。
+func TestLLMExtractorResolveDecisionSupportsCreateAndSkip(t *testing.T) {
+ t.Run("create", func(t *testing.T) {
+ extractor := NewLLMExtractor(&stubTextGenerator{
+ response: `{"action":"create","type":"project","title":"上线计划","content":"项目将在 2026-05-12 上线。","keywords":["release"]}`,
+ })
+ decision, err := extractor.ResolveDecision(
+ context.Background(),
+ Entry{Type: TypeProject, Title: "上线计划", Content: "项目将在下周上线。"},
+ []ExtractionCandidate{{Ref: "project:old.md", Scope: ScopeProject, Type: TypeProject, Source: SourceUserManual, Title: "旧计划", Content: "历史计划"}},
+ )
+ if err != nil {
+ t.Fatalf("ResolveDecision(create) error = %v", err)
+ }
+ if decision.Action != ExtractionActionCreate || decision.Entry.Type != TypeProject {
+ t.Fatalf("unexpected create decision: %+v", decision)
+ }
+ })
- _, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("remember this")}},
- {
- Role: providertypes.RoleAssistant,
- ToolCalls: []providertypes.ToolCall{
- {ID: "call_1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`},
- },
- },
- {
- Role: providertypes.RoleTool,
- ToolCallID: "call_1",
- ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"},
- },
+ t.Run("skip", func(t *testing.T) {
+ extractor := NewLLMExtractor(&stubTextGenerator{
+ response: `{"action":"skip","ref":"user:u.md"}`,
+ })
+ decision, err := extractor.ResolveDecision(
+ context.Background(),
+ Entry{Type: TypeUser, Title: "中文回复", Content: "用户偏好中文回复。"},
+ []ExtractionCandidate{{Ref: "user:u.md", Scope: ScopeUser, Type: TypeUser, Source: SourceUserManual, Title: "中文回复", Content: "用户偏好中文回复。"}},
+ )
+ if err != nil {
+ t.Fatalf("ResolveDecision(skip) error = %v", err)
+ }
+ if decision.Action != ExtractionActionSkip || decision.Ref != "user:u.md" {
+ t.Fatalf("unexpected skip decision: %+v", decision)
+ }
})
- if err != nil {
- t.Fatalf("Extract() error = %v", err)
- }
- if len(generator.messages) != 3 {
- t.Fatalf("len(generator.messages) = %d, want 3", len(generator.messages))
- }
- toolMessage := generator.messages[2]
- if toolMessage.Role != providertypes.RoleTool {
- t.Fatalf("expected projected tool message, got %#v", toolMessage)
- }
- toolText := renderMemoParts(toolMessage.Parts)
- if !strings.Contains(toolText, "tool result") ||
- !strings.Contains(toolText, "tool: filesystem_read_file") ||
- !strings.Contains(toolText, "meta.path: README.md") {
- t.Fatalf("expected metadata-only tool text, got %q", toolText)
- }
- if strings.Contains(toolText, "content:\n") {
- t.Fatalf("expected metadata-only projection to omit content section, got %q", toolText)
- }
- if toolMessage.ToolMetadata != nil {
- t.Fatalf("expected projected tool metadata to be cleared, got %#v", toolMessage.ToolMetadata)
- }
}
-func TestLLMExtractorExtractSkipsOrphanAndClearedToolMessages(t *testing.T) {
- generator := &stubTextGenerator{response: `[]`}
- extractor := NewLLMExtractor(generator, 10)
+// TestLLMExtractorErrors 验证取消、空生成器和上游错误会正确透传。
+func TestLLMExtractorErrors(t *testing.T) {
+ t.Run("canceled context", func(t *testing.T) {
+ generator := &stubTextGenerator{response: `[]`}
+ extractor := NewLLMExtractor(generator)
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
- _, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("alpha")}},
- {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan result")}},
- {
- Role: providertypes.RoleAssistant,
- ToolCalls: []providertypes.ToolCall{
- {ID: "call_1", Name: "bash", Arguments: `{}`},
- },
- },
- {Role: providertypes.RoleTool, ToolCallID: "call_1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("[Old tool result content cleared]")}},
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("beta")}},
- })
- if err != nil {
- t.Fatalf("Extract() error = %v", err)
- }
- if len(generator.messages) != 2 {
- t.Fatalf("len(generator.messages) = %d, want 2", len(generator.messages))
- }
- for _, message := range generator.messages {
- if message.Role == providertypes.RoleTool || len(message.ToolCalls) > 0 {
- t.Fatalf("unexpected tool-related message in extraction window: %#v", message)
+ _, err := extractor.Extract(ctx, []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}},
+ })
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("Extract() error = %v, want context.Canceled", err)
}
- }
-}
-
-func TestLLMExtractorExtractNilGenerator(t *testing.T) {
- var extractor *LLMExtractor
- _, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}},
})
- if err == nil || !strings.Contains(err.Error(), "text generator is nil") {
- t.Fatalf("Extract() error = %v", err)
- }
- extractor = NewLLMExtractor(nil, 10)
- _, err = extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}},
+ t.Run("nil generator", func(t *testing.T) {
+ extractor := NewLLMExtractor(nil)
+ _, err := extractor.Extract(context.Background(), []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}},
+ })
+ if err == nil || !strings.Contains(err.Error(), "text generator is nil") {
+ t.Fatalf("Extract() error = %v", err)
+ }
})
- if err == nil || !strings.Contains(err.Error(), "text generator is nil") {
- t.Fatalf("Extract() error = %v", err)
- }
-}
-func TestLLMExtractorExtractGeneratorFailure(t *testing.T) {
- extractor := NewLLMExtractor(&stubTextGenerator{err: errors.New("upstream failed")}, 10)
- _, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}},
+ t.Run("generator failure", func(t *testing.T) {
+ extractor := NewLLMExtractor(&stubTextGenerator{err: errors.New("upstream failed")})
+ _, err := extractor.Extract(context.Background(), []providertypes.Message{
+ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}},
+ })
+ if err == nil || !strings.Contains(err.Error(), "upstream failed") {
+ t.Fatalf("Extract() error = %v", err)
+ }
})
- if err == nil || !strings.Contains(err.Error(), "upstream failed") {
- t.Fatalf("Extract() error = %v", err)
- }
}
-func TestExtractJSONArrayErrors(t *testing.T) {
+// TestJSONPayloadExtractors 验证数组与对象提取器的错误分支。
+func TestJSONPayloadExtractors(t *testing.T) {
if _, err := extractJSONArray("no json here"); err == nil || !strings.Contains(err.Error(), "does not contain") {
t.Fatalf("expected missing array error, got %v", err)
}
if _, err := extractJSONArray(`[{"a":"x"}`); err == nil || !strings.Contains(err.Error(), "incomplete") {
t.Fatalf("expected incomplete array error, got %v", err)
}
-}
-
-func TestLLMExtractorExtractImageOnlyUserMessageSkipsGenerator(t *testing.T) {
- generator := &stubTextGenerator{response: `[]`}
- extractor := NewLLMExtractor(generator, 10)
-
- entries, err := extractor.Extract(context.Background(), []providertypes.Message{
- {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewRemoteImagePart("https://example.com/pic.png")}},
- })
- if err != nil {
- t.Fatalf("Extract() error = %v", err)
- }
- if len(entries) != 0 {
- t.Fatalf("len(entries) = %d, want 0", len(entries))
+ if _, err := extractJSONObject("no json here"); err == nil || !strings.Contains(err.Error(), "does not contain") {
+ t.Fatalf("expected missing object error, got %v", err)
}
- if generator.calls != 0 {
- t.Fatalf("Generate() calls = %d, want 0", generator.calls)
+ if _, err := extractJSONObject(`{"a":"x"`); err == nil || !strings.Contains(err.Error(), "incomplete") {
+ t.Fatalf("expected incomplete object error, got %v", err)
}
}
diff --git a/internal/memo/semantic_candidates.go b/internal/memo/semantic_candidates.go
new file mode 100644
index 00000000..01f911ca
--- /dev/null
+++ b/internal/memo/semantic_candidates.go
@@ -0,0 +1,321 @@
+package memo
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "strings"
+ "unicode"
+)
+
+const (
+ semanticCandidateShortlistLimit = 5
+ semanticCandidateContentMaxRunes = 240
+)
+
+type scoredExtractionCandidate struct {
+ candidate ExtractionCandidate
+ score int
+}
+
+// semanticCandidateShortlist 为候选记忆检索相关 existing memo,并限制进入 LLM 的数量。
+func (s *Service) semanticCandidateShortlist(
+ ctx context.Context,
+ entry Entry,
+ limit int,
+) ([]ExtractionCandidate, error) {
+ if s == nil {
+ return nil, nil
+ }
+ if limit <= 0 {
+ limit = semanticCandidateShortlistLimit
+ }
+ if err := s.ensureSemanticCandidateIndex(ctx); err != nil {
+ return nil, err
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ scored := make([]scoredExtractionCandidate, 0, len(s.semanticCandidatesByRef))
+ for _, candidate := range s.semanticCandidatesByRef {
+ score := scoreExtractionCandidate(entry, candidate)
+ if score <= 0 {
+ continue
+ }
+ scored = append(scored, scoredExtractionCandidate{
+ candidate: cloneExtractionCandidate(candidate),
+ score: score,
+ })
+ }
+ sort.SliceStable(scored, func(i, j int) bool {
+ if scored[i].score != scored[j].score {
+ return scored[i].score > scored[j].score
+ }
+ return scored[i].candidate.Ref < scored[j].candidate.Ref
+ })
+ if len(scored) > limit {
+ scored = scored[:limit]
+ }
+
+ result := make([]ExtractionCandidate, 0, len(scored))
+ for _, item := range scored {
+ result = append(result, item.candidate)
+ }
+ return result, nil
+}
+
+// ensureSemanticCandidateIndex 懒加载语义去重候选索引,避免每轮 accepted run 都重扫 topic 文件。
+func (s *Service) ensureSemanticCandidateIndex(ctx context.Context) error {
+ s.mu.Lock()
+ if s.semanticIndexReady {
+ s.mu.Unlock()
+ return nil
+ }
+ s.mu.Unlock()
+
+ s.semanticIndexMu.Lock()
+ defer s.semanticIndexMu.Unlock()
+
+ s.mu.Lock()
+ if s.semanticIndexReady {
+ s.mu.Unlock()
+ return nil
+ }
+ s.mu.Unlock()
+
+ candidatesByRef := make(map[string]ExtractionCandidate)
+ for _, scope := range supportedStorageScopes() {
+ index, err := s.store.LoadIndex(ctx, scope)
+ if err != nil {
+ return fmt.Errorf("memo: load index: %w", err)
+ }
+ for _, entry := range index.Entries {
+ topicFile := strings.TrimSpace(entry.TopicFile)
+ if topicFile == "" {
+ continue
+ }
+ topicContent, err := s.store.LoadTopic(ctx, scope, topicFile)
+ if err != nil {
+ continue
+ }
+ candidate := buildExtractionCandidateFromTopic(scope, entry, topicContent)
+ if candidate.Ref == "" {
+ continue
+ }
+ candidatesByRef[candidate.Ref] = candidate
+ }
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.semanticIndexReady {
+ return nil
+ }
+ s.semanticCandidatesByRef = candidatesByRef
+ s.semanticIndexReady = true
+ return nil
+}
+
+// trackSemanticCandidateLocked 在语义索引已就绪时同步维护单条 memo 的 shortlist 快照。
+func (s *Service) trackSemanticCandidateLocked(scope Scope, entry Entry) {
+ if !s.semanticIndexReady {
+ return
+ }
+ topicFile := strings.TrimSpace(entry.TopicFile)
+ if topicFile == "" {
+ return
+ }
+ candidate := buildExtractionCandidateFromEntry(scope, entry)
+ if candidate.Ref == "" {
+ return
+ }
+ if s.semanticCandidatesByRef == nil {
+ s.semanticCandidatesByRef = make(map[string]ExtractionCandidate)
+ }
+ s.semanticCandidatesByRef[candidate.Ref] = candidate
+}
+
+// removeSemanticCandidateLocked 从语义索引中删除指定 topic 的候选快照。
+func (s *Service) removeSemanticCandidateLocked(scope Scope, topicFile string) {
+ if !s.semanticIndexReady {
+ return
+ }
+ topicFile = strings.TrimSpace(topicFile)
+ if topicFile == "" {
+ return
+ }
+ delete(s.semanticCandidatesByRef, scopedTopicKey(scope, topicFile))
+}
+
+// buildExtractionCandidateFromEntry 将内存中的完整 Entry 收敛为 shortlist 快照。
+func buildExtractionCandidateFromEntry(scope Scope, entry Entry) ExtractionCandidate {
+ topicFile := strings.TrimSpace(entry.TopicFile)
+ if topicFile == "" {
+ return ExtractionCandidate{}
+ }
+ return ExtractionCandidate{
+ Ref: scopedTopicKey(scope, topicFile),
+ Scope: scope,
+ Type: entry.Type,
+ Source: strings.TrimSpace(entry.Source),
+ Title: NormalizeTitle(entry.Title),
+ Keywords: normalizeKeywords(entry.Keywords),
+ Content: truncateSemanticContent(entry.Content),
+ }
+}
+
+// buildExtractionCandidateFromTopic 将 topic 文件内容解析为 shortlist 快照。
+func buildExtractionCandidateFromTopic(scope Scope, entry Entry, topicContent string) ExtractionCandidate {
+ source, keywords, content := parseTopicSnapshot(topicContent)
+ entry.Source = source
+ entry.Keywords = keywords
+ entry.Content = content
+ return buildExtractionCandidateFromEntry(scope, entry)
+}
+
+// parseTopicSnapshot 解析 topic frontmatter 中的 source、keywords 与正文内容。
+func parseTopicSnapshot(topic string) (string, []string, string) {
+ parts := strings.Split(topic, "---")
+ if len(parts) < 3 {
+ return "", nil, strings.TrimSpace(topic)
+ }
+
+ var (
+ source string
+ keywords []string
+ )
+ frontmatter := parts[1]
+ body := strings.TrimSpace(strings.Join(parts[2:], "---"))
+ for _, line := range strings.Split(frontmatter, "\n") {
+ line = strings.TrimSpace(line)
+ switch {
+ case strings.HasPrefix(line, "source:"):
+ source = strings.TrimSpace(strings.TrimPrefix(line, "source:"))
+ case strings.HasPrefix(line, "keywords:"):
+ raw := strings.TrimSpace(strings.TrimPrefix(line, "keywords:"))
+ keywords = parseTopicKeywords(raw)
+ }
+ }
+ return source, keywords, body
+}
+
+// parseTopicKeywords 解析 frontmatter 中的单行 keywords 列表。
+func parseTopicKeywords(raw string) []string {
+ raw = strings.TrimSpace(raw)
+ raw = strings.TrimPrefix(raw, "[")
+ raw = strings.TrimSuffix(raw, "]")
+ if raw == "" {
+ return nil
+ }
+
+ parts := strings.Split(raw, ",")
+ keywords := make([]string, 0, len(parts))
+ for _, part := range parts {
+ part = strings.TrimSpace(strings.Trim(part, `"'`))
+ if part == "" {
+ continue
+ }
+ keywords = append(keywords, part)
+ }
+ return normalizeKeywords(keywords)
+}
+
+// truncateSemanticContent 对候选正文做固定上限截断,避免把全量 memo 内容送入 prompt。
+func truncateSemanticContent(content string) string {
+ content = strings.TrimSpace(content)
+ if content == "" {
+ return ""
+ }
+ runes := []rune(content)
+ if len(runes) <= semanticCandidateContentMaxRunes {
+ return content
+ }
+ return string(runes[:semanticCandidateContentMaxRunes]) + "..."
+}
+
+// scoreExtractionCandidate 为 shortlist 候选打分,优先保留最相关的 memo。
+func scoreExtractionCandidate(target Entry, existing ExtractionCandidate) int {
+ targetTitle := strings.ToLower(NormalizeTitle(target.Title))
+ targetContent := strings.ToLower(strings.TrimSpace(target.Content))
+ existingTitle := strings.ToLower(NormalizeTitle(existing.Title))
+ existingContent := strings.ToLower(strings.TrimSpace(existing.Content))
+
+ score := 0
+ if target.Type == existing.Type {
+ score += 80
+ }
+ if targetTitle != "" && targetTitle == existingTitle {
+ score += 1000
+ }
+ if targetContent != "" && targetContent == existingContent {
+ score += 900
+ }
+ if targetTitle != "" && strings.Contains(existingContent, targetTitle) {
+ score += 120
+ }
+ if existingTitle != "" && strings.Contains(targetContent, existingTitle) {
+ score += 120
+ }
+
+ targetKeywordSet := tokenSet(append([]string{target.Title, target.Content}, target.Keywords...)...)
+ existingKeywordSet := tokenSet(append([]string{existing.Title, existing.Content}, existing.Keywords...)...)
+ for token := range targetKeywordSet {
+ if _, ok := existingKeywordSet[token]; ok {
+ score += 10
+ }
+ }
+ for _, keyword := range normalizeKeywords(target.Keywords) {
+ normalized := strings.ToLower(strings.TrimSpace(keyword))
+ if normalized == "" {
+ continue
+ }
+ for _, existingKeyword := range existing.Keywords {
+ if normalized == strings.ToLower(strings.TrimSpace(existingKeyword)) {
+ score += 40
+ break
+ }
+ }
+ }
+ return score
+}
+
+// tokenSet 将多段文本规整为去重后的 token 集合,供 shortlist 相关性排序使用。
+func tokenSet(parts ...string) map[string]struct{} {
+ set := make(map[string]struct{})
+ for _, part := range parts {
+ for _, token := range tokenizeSemanticText(part) {
+ set[token] = struct{}{}
+ }
+ }
+ return set
+}
+
+// tokenizeSemanticText 按字母数字边界切分文本,生成 shortlist 排序所需的归一化 token。
+func tokenizeSemanticText(text string) []string {
+ text = strings.TrimSpace(strings.ToLower(text))
+ if text == "" {
+ return nil
+ }
+ parts := strings.FieldsFunc(text, func(r rune) bool {
+ return !unicode.IsLetter(r) && !unicode.IsNumber(r)
+ })
+ tokens := make([]string, 0, len(parts))
+ for _, part := range parts {
+ part = strings.TrimSpace(part)
+ if part == "" {
+ continue
+ }
+ tokens = append(tokens, part)
+ }
+ return tokens
+}
+
+// cloneExtractionCandidate 深拷贝 shortlist 候选,避免切片字段共享底层数组。
+func cloneExtractionCandidate(candidate ExtractionCandidate) ExtractionCandidate {
+ cloned := candidate
+ if len(candidate.Keywords) > 0 {
+ cloned.Keywords = append([]string(nil), candidate.Keywords...)
+ }
+ return cloned
+}
diff --git a/internal/memo/service.go b/internal/memo/service.go
index 42551f26..2e30e0a3 100644
--- a/internal/memo/service.go
+++ b/internal/memo/service.go
@@ -14,14 +14,17 @@ import (
// Service 编排记忆的存储、检索、删除和索引维护,是 memo 子系统对外的统一入口。
type Service struct {
- store Store
- config config.MemoConfig
- mu sync.Mutex
- sourceInvl func()
- autoExtractIndexMu sync.Mutex
- autoExtractIndexReady bool
- autoExtractKeysByTopic map[string]string
- autoExtractKeyRefs map[string]int
+ store Store
+ config config.MemoConfig
+ mu sync.Mutex
+ sourceInvl func()
+ autoExtractIndexMu sync.Mutex
+ autoExtractIndexReady bool
+ autoExtractKeysByTopic map[string]string
+ autoExtractKeyRefs map[string]int
+ semanticIndexMu sync.Mutex
+ semanticIndexReady bool
+ semanticCandidatesByRef map[string]ExtractionCandidate
}
// NewService 创建 memo Service 实例。
@@ -187,6 +190,15 @@ func (s *Service) updateAutoExtractIfAllowed(ctx context.Context, ref string, ne
s.trackAutoExtractEntryLocked(scope, current)
}
}
+ if s.semanticIndexReady {
+ s.removeSemanticCandidateLocked(scope, topicFile)
+ for _, removed := range removedEntries {
+ s.removeSemanticCandidateLocked(scope, removed.TopicFile)
+ }
+ if indexContainsTopicFile(working, topicFile) {
+ s.trackSemanticCandidateLocked(scope, current)
+ }
+ }
s.invalidateCache()
return true, nil
@@ -239,6 +251,7 @@ func (s *Service) Remove(ctx context.Context, keyword string, scope Scope) (int,
if topicFile := strings.TrimSpace(entry.TopicFile); topicFile != "" {
_ = s.store.DeleteTopic(ctx, bucket, topicFile)
s.removeAutoExtractTopicLocked(bucket, topicFile)
+ s.removeSemanticCandidateLocked(bucket, topicFile)
}
}
removed += len(removedEntries)
@@ -393,6 +406,17 @@ func (s *Service) saveEntryLocked(ctx context.Context, entry Entry) error {
s.trackAutoExtractEntryLocked(scope, entry)
}
}
+ if s.semanticIndexReady {
+ if replaced {
+ s.removeSemanticCandidateLocked(scope, previous.TopicFile)
+ }
+ for _, removed := range removedEntries {
+ s.removeSemanticCandidateLocked(scope, removed.TopicFile)
+ }
+ if indexContainsEntryID(working, entry.ID) {
+ s.trackSemanticCandidateLocked(scope, entry)
+ }
+ }
s.invalidateCache()
return nil
diff --git a/internal/memo/service_test.go b/internal/memo/service_test.go
index 988b71ff..462d17f3 100644
--- a/internal/memo/service_test.go
+++ b/internal/memo/service_test.go
@@ -3,6 +3,7 @@ package memo
import (
"context"
"errors"
+ "fmt"
"strings"
"testing"
@@ -11,10 +12,9 @@ import (
func testMemoConfig() config.MemoConfig {
return config.MemoConfig{
- MaxEntries: 200,
- MaxIndexBytes: 16 * 1024,
- ExtractTimeoutSec: 15,
- ExtractRecentMessages: 10,
+ MaxEntries: 200,
+ MaxIndexBytes: 16 * 1024,
+ ExtractTimeoutSec: 15,
}
}
@@ -300,10 +300,9 @@ func TestServiceAutoExtractDedupAcrossScopes(t *testing.T) {
func TestServiceAutoExtractTrimmedEntryDoesNotPolluteDedupIndex(t *testing.T) {
svc := NewService(newMemoryTestStore(), config.MemoConfig{
- MaxEntries: 10,
- MaxIndexBytes: 1,
- ExtractTimeoutSec: 15,
- ExtractRecentMessages: 10,
+ MaxEntries: 10,
+ MaxIndexBytes: 1,
+ ExtractTimeoutSec: 15,
}, nil)
entry := Entry{
Type: TypeUser,
@@ -375,6 +374,82 @@ func TestMatchesKeywordIncludesContent(t *testing.T) {
}
}
+func TestServiceSemanticCandidateShortlistMatchesRelevantMemo(t *testing.T) {
+ store := NewFileStore(t.TempDir(), t.TempDir())
+ svc := NewService(store, testMemoConfig(), nil)
+
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeFeedback,
+ Title: "测试策略",
+ Content: "用户要求修改后先跑测试。",
+ Source: SourceAutoExtract,
+ }); err != nil {
+ t.Fatalf("seed relevant Add() error = %v", err)
+ }
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeUser,
+ Title: "语言偏好",
+ Content: "用户偏好中文回复。",
+ Source: SourceUserManual,
+ }); err != nil {
+ t.Fatalf("seed unrelated Add() error = %v", err)
+ }
+
+ shortlist, err := svc.semanticCandidateShortlist(context.Background(), Entry{
+ Type: TypeFeedback,
+ Title: "测试策略",
+ Content: "用户要求修改后先跑相关测试。",
+ }, semanticCandidateShortlistLimit)
+ if err != nil {
+ t.Fatalf("semanticCandidateShortlist() error = %v", err)
+ }
+ if len(shortlist) != 1 {
+ t.Fatalf("len(shortlist) = %d, want 1", len(shortlist))
+ }
+ if shortlist[0].Title != "测试策略" || shortlist[0].Source != SourceAutoExtract {
+ t.Fatalf("unexpected shortlist candidate: %+v", shortlist[0])
+ }
+}
+
+func TestServiceSemanticCandidateShortlistLimitsAndTruncates(t *testing.T) {
+ store := NewFileStore(t.TempDir(), t.TempDir())
+ svc := NewService(store, testMemoConfig(), nil)
+
+ longContent := strings.Repeat("release-note ", 40)
+ for index := 0; index < 6; index++ {
+ if err := svc.Add(context.Background(), Entry{
+ Type: TypeProject,
+ Title: fmt.Sprintf("发布计划 %d", index),
+ Content: longContent + fmt.Sprintf(" milestone-%d", index),
+ Keywords: []string{"release"},
+ Source: SourceAutoExtract,
+ }); err != nil {
+ t.Fatalf("seed Add(%d) error = %v", index, err)
+ }
+ }
+
+ shortlist, err := svc.semanticCandidateShortlist(context.Background(), Entry{
+ Type: TypeProject,
+ Title: "发布计划",
+ Content: "release milestone",
+ Keywords: []string{"release"},
+ }, semanticCandidateShortlistLimit)
+ if err != nil {
+ t.Fatalf("semanticCandidateShortlist() error = %v", err)
+ }
+ if len(shortlist) != semanticCandidateShortlistLimit {
+ t.Fatalf("len(shortlist) = %d, want %d", len(shortlist), semanticCandidateShortlistLimit)
+ }
+ for _, candidate := range shortlist {
+ if len([]rune(candidate.Content)) > semanticCandidateContentMaxRunes+3 {
+ t.Fatalf("candidate content should be truncated, got %q", candidate.Content)
+ }
+ if len(candidate.Keywords) != 1 || candidate.Keywords[0] != "release" {
+ t.Fatalf("candidate keywords should be preserved, got %+v", candidate)
+ }
+ }
+}
+
func TestTrimIndexEntriesByBytesRemovesMinimalPrefix(t *testing.T) {
index := &Index{
Entries: []Entry{
diff --git a/internal/memo/types.go b/internal/memo/types.go
index 0d1a62a0..04d2d084 100644
--- a/internal/memo/types.go
+++ b/internal/memo/types.go
@@ -103,12 +103,13 @@ const (
// ExtractionCandidate 表示提供给模型做语义去重的既有记忆快照。
type ExtractionCandidate struct {
- Ref string `json:"ref"`
- Scope Scope `json:"scope"`
- Type Type `json:"type"`
- Source string `json:"source"`
- Title string `json:"title"`
- Content string `json:"content"`
+ Ref string `json:"ref"`
+ Scope Scope `json:"scope"`
+ Type Type `json:"type"`
+ Source string `json:"source"`
+ Title string `json:"title"`
+ Keywords []string `json:"keywords,omitempty"`
+ Content string `json:"content"`
}
// ExtractionDecision 表示模型针对新旧记忆关系返回的结构化决策。
@@ -133,13 +134,13 @@ type Extractor interface {
Extract(ctx context.Context, messages []providertypes.Message) ([]Entry, error)
}
-// DecisionExtractor 定义带既有记忆快照的语义提取能力。
-type DecisionExtractor interface {
- ExtractDecisions(
+// DecisionResolver 定义针对单条候选记忆做去重决策的最小能力。
+type DecisionResolver interface {
+ ResolveDecision(
ctx context.Context,
- messages []providertypes.Message,
+ candidate Entry,
existing []ExtractionCandidate,
- ) ([]ExtractionDecision, error)
+ ) (ExtractionDecision, error)
}
// TextGenerator 定义调用 LLM 生成文本的最小能力,用于记忆提取。
diff --git a/internal/tools/memo/remember_test.go b/internal/tools/memo/remember_test.go
index ac0d39f8..bf4defef 100644
--- a/internal/tools/memo/remember_test.go
+++ b/internal/tools/memo/remember_test.go
@@ -15,10 +15,9 @@ func newTestService(t *testing.T) *memo.Service {
t.Helper()
store := memo.NewFileStore(t.TempDir(), t.TempDir())
return memo.NewService(store, config.MemoConfig{
- MaxEntries: 200,
- MaxIndexBytes: 16 * 1024,
- ExtractTimeoutSec: 15,
- ExtractRecentMessages: 10,
+ MaxEntries: 200,
+ MaxIndexBytes: 16 * 1024,
+ ExtractTimeoutSec: 15,
}, nil)
}