diff --git a/docs/reference/gateway-rpc-api.md b/docs/reference/gateway-rpc-api.md index 3fc19a47..4ec01eec 100644 --- a/docs/reference/gateway-rpc-api.md +++ b/docs/reference/gateway-rpc-api.md @@ -440,7 +440,7 @@ Response Schema: "saved_ratio": 0.63, "trigger_mode": "manual", "transcript_id": "compact-1", - "transcript_path": ".neocode/transcripts/compact-1.md" + "transcript_path": ".neocode/transcripts/compact-subagent.md" } } } @@ -1147,7 +1147,7 @@ Success Response: "SavedRatio": 0.63, "TriggerMode": "manual", "TranscriptID": "compact-demo-1", - "TranscriptPath": ".neocode/transcripts/compact-demo-1.md" + "TranscriptPath": ".neocode/transcripts/compact-demo-subagent.md" } } } diff --git a/go.mod b/go.mod index 19f62d3a..2d6feada 100644 --- a/go.mod +++ b/go.mod @@ -80,6 +80,7 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/odvcencio/gotreesitter v0.15.3 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect diff --git a/go.sum b/go.sum index c767cc53..3a974b11 100644 --- a/go.sum +++ b/go.sum @@ -191,6 +191,10 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/odvcencio/gotreesitter v0.15.3 h1:bcSIEMyRrDaFIZw2zwM8cNR03VX6y8CbXFxVzfFSGX0= +github.com/odvcencio/gotreesitter v0.15.3/go.mod h1:ccYZsDUmAJQAtliLsNHT33F3X4AN7f/Z6JGiPNZoEzY= +github.com/openai/openai-go/v3 v3.30.0 h1:T8VkhqAm6BuvxwpVG+Aw+H4TcYIsbj9nqytjpWcE/aU= +github.com/openai/openai-go/v3 v3.30.0/go.mod 
h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/openai/openai-go/v3 v3.32.0 h1:aHp/3wkX1W6jB8zTtf9xV0aK0qPFSVDqS7AHmlJ4hXs= github.com/openai/openai-go/v3 v3.32.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 96211b42..358c03e2 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -21,12 +21,14 @@ import ( "neo-code/internal/provider/builtin" providercatalog "neo-code/internal/provider/catalog" providertypes "neo-code/internal/provider/types" + "neo-code/internal/repository" agentruntime "neo-code/internal/runtime" "neo-code/internal/security" agentsession "neo-code/internal/session" "neo-code/internal/skills" "neo-code/internal/tools" "neo-code/internal/tools/bash" + "neo-code/internal/tools/codebase" diagnosetool "neo-code/internal/tools/diagnose" "neo-code/internal/tools/filesystem" "neo-code/internal/tools/mcp" @@ -460,6 +462,10 @@ func buildToolRegistry(cfg config.Config) (*tools.Registry, func() error, error) })) toolRegistry.Register(todo.New()) toolRegistry.Register(spawnsubagent.New()) + repoSvc := repository.NewService() + toolRegistry.Register(codebase.NewRead(repoSvc, cfg.Workdir)) + toolRegistry.Register(codebase.NewSearchText(repoSvc, cfg.Workdir)) + toolRegistry.Register(codebase.NewSearchSymbol(repoSvc, cfg.Workdir)) mcpRegistry, err := BuildMCPRegistry(cfg) if err != nil { return nil, nil, err diff --git a/internal/checkpoint/bash_capture_test.go b/internal/checkpoint/bash_capture_test.go index ecfbf34f..246fc17b 100644 --- a/internal/checkpoint/bash_capture_test.go +++ b/internal/checkpoint/bash_capture_test.go @@ -220,3 +220,15 @@ func TestSourceFilesInWorkdir_HandlesEmptyWorkdir(t *testing.T) { t.Fatalf("expected nil with empty workdir, got %v", got) } } + +func equalStringSlice(a, b []string) bool { + if len(a) != len(b) { + return false + } + for 
i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/checkpoint/per_edit_snapshot.go b/internal/checkpoint/per_edit_snapshot.go index 70fbd2b7..3c759655 100644 --- a/internal/checkpoint/per_edit_snapshot.go +++ b/internal/checkpoint/per_edit_snapshot.go @@ -17,6 +17,8 @@ import ( "time" "github.com/pmezard/go-difflib/difflib" + + "neo-code/internal/repository" ) const ( @@ -446,20 +448,17 @@ func (s *PerEditSnapshotStore) HasPending() bool { return len(s.pending) > 0 } -// FileChangeKind 表示两个 checkpoint 之间单个 path 的变更类别。 -type FileChangeKind string +// FileChangeKind 是 repository.FileChangeKind 的别名,保留以维持向后兼容。 +type FileChangeKind = repository.FileChangeKind const ( - FileChangeAdded FileChangeKind = "added" - FileChangeDeleted FileChangeKind = "deleted" - FileChangeModified FileChangeKind = "modified" + FileChangeAdded = repository.FileChangeAdded + FileChangeDeleted = repository.FileChangeDeleted + FileChangeModified = repository.FileChangeModified ) -// FileChangeEntry 描述端到端 diff 中单个 path 的变更。 -type FileChangeEntry struct { - Path string - Kind FileChangeKind -} +// FileChangeEntry 是 repository.FileChangeEntry 的别名,保留以维持向后兼容。 +type FileChangeEntry = repository.FileChangeEntry // ChangedFiles 端到端比较两个 checkpoint,返回 path → 变更类别的列表(按 path 字典序)。 // 不返回内容差异,仅用于 UI 分组(添加/删除/修改)。完整 patch 仍由 Diff 生成。 diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index 3ae60863..283d1695 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -591,7 +591,7 @@ func TestGatewayRuntimePortBridgeRuntimeMethods(t *testing.T) { SavedRatio: 0.5, TriggerMode: "manual", TranscriptID: "tx-1", - TranscriptPath: "/tmp/tx-1.md", + TranscriptPath: "/tmp/tx-subagent.md", }, systemToolRes: tools.ToolResult{ ToolCallID: "call-system-1", diff --git a/internal/config/provider.go b/internal/config/provider.go index fc968a1d..494c8994 100644 --- 
a/internal/config/provider.go +++ b/internal/config/provider.go @@ -734,7 +734,7 @@ const ( GLMDefaultAPIKeyEnv = "GLM_API_KEY" MiMoName = "mimo" - MiMoDefaultBaseURL = "https://api.xiaomimimo.com/v1" + MiMoDefaultBaseURL = "https://token-plan-cn.xiaomimimo.com/v1" MiMoDefaultModel = "mimo-v2.5-pro" MiMoDefaultAPIKeyEnv = "MIMO_API_KEY" diff --git a/internal/context/prompt_test.go b/internal/context/prompt_test.go index 7356cbc7..7c933e8f 100644 --- a/internal/context/prompt_test.go +++ b/internal/context/prompt_test.go @@ -167,7 +167,7 @@ func TestDefaultToolUsagePromptIncludesPermissionAndAntiLoopGuidance(t *testing. if !strings.Contains(toolUsage, "`status`, `ok`, `tool_call_id`, `truncated`, `meta.*`, exit codes, and `content`") { t.Fatalf("expected Tool Usage to explain structured tool results, got %q", toolUsage) } - if !strings.Contains(toolUsage, "inspect (`git status`/`git diff`/`git log`)") { + if !strings.Contains(toolUsage, "Use Git through dedicated `git_*` tools") { t.Fatalf("expected Tool Usage to describe git inspection order, got %q", toolUsage) } if !strings.Contains(toolUsage, "Prefer rollback primitives in this order: `git restore`") { diff --git a/internal/context/repository/repository_additional_test.go b/internal/context/repository/repository_additional_test.go deleted file mode 100644 index aabcd7e6..00000000 --- a/internal/context/repository/repository_additional_test.go +++ /dev/null @@ -1,1009 +0,0 @@ -package repository - -import ( - "context" - "errors" - "fmt" - "os" - "path/filepath" - "slices" - "strings" - "testing" -) - -func TestLoadGitSnapshotGuardsAndErrorFallbacks(t *testing.T) { - t.Parallel() - - t.Run("context canceled", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - snapshot, err := (&Service{}).loadGitSnapshot(ctx, t.TempDir()) - if !errors.Is(err, context.Canceled) { - t.Fatalf("loadGitSnapshot() err = %v, want context canceled", err) - } - if 
snapshot.InGitRepo || snapshot.Branch != "" || snapshot.Ahead != 0 || snapshot.Behind != 0 || len(snapshot.Entries) != 0 { - t.Fatalf("expected empty snapshot, got %+v", snapshot) - } - }) - - t.Run("empty workdir or nil runner", func(t *testing.T) { - t.Parallel() - - service := &Service{} - if snapshot, err := service.loadGitSnapshot(context.Background(), " "); err != nil || snapshot.InGitRepo || len(snapshot.Entries) != 0 { - t.Fatalf("loadGitSnapshot(empty) = (%+v, %v), want empty nil", snapshot, err) - } - }) - - t.Run("non git returns empty and generic error bubbles up", func(t *testing.T) { - t.Parallel() - - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - return gitCommandOutput{text: "fatal: not a git repository"}, errors.New("exit status 128") - }, - } - snapshot, err := service.loadGitSnapshot(context.Background(), t.TempDir()) - if err != nil { - t.Fatalf("loadGitSnapshot(non-git) err = %v", err) - } - if snapshot.InGitRepo || len(snapshot.Entries) != 0 { - t.Fatalf("expected empty snapshot, got %+v", snapshot) - } - - service.gitRunner = func(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - return gitCommandOutput{}, errors.New("boom") - } - _, err = service.loadGitSnapshot(context.Background(), t.TempDir()) - if err == nil { - t.Fatalf("expected generic git error to bubble up") - } - }) - - t.Run("context error from runner bubbles up", func(t *testing.T) { - t.Parallel() - - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - return gitCommandOutput{}, context.DeadlineExceeded - }, - } - _, err := service.loadGitSnapshot(context.Background(), t.TempDir()) - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("loadGitSnapshot() err = %v, want deadline exceeded", err) - } - }) -} - -func 
TestChangedFileSnippetBranches(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "modified.go"), "package pkg\n\nfunc New(){}\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "renamed.go"), "package pkg\n\nfunc Renamed(){}\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "added.go"), "package pkg\n\nfunc Added() {}\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "untracked.go"), "package pkg\n\nfunc NewFile() {}\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "error.go"), "package pkg\n\nfunc Error(){}\n") - - service := &Service{ - gitRunner: func(ctx context.Context, dir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - command := strings.Join(args, " ") - switch command { - case "diff --unified=3 HEAD -- pkg/modified.go": - return gitCommandOutput{text: "@@ -1,1 +1,1 @@\n-func Old(){}\n+func New(){}\n"}, nil - case "diff --unified=3 HEAD -- pkg/renamed.go": - return gitCommandOutput{text: "@@ -1,1 +1,1 @@\n-old\n+new\n"}, nil - case "diff --unified=3 HEAD -- pkg/error.go": - return gitCommandOutput{}, context.Canceled - default: - return gitCommandOutput{}, nil - } - }, - readFile: readFile, - } - - tests := []struct { - name string - entry gitChangedEntry - wantErr error - wantSnippet string - }{ - {name: "deleted", entry: gitChangedEntry{Path: "pkg/deleted.go", Status: StatusDeleted}}, - {name: "conflicted", entry: gitChangedEntry{Path: "pkg/conflicted.go", Status: StatusConflicted}}, - {name: "modified", entry: gitChangedEntry{Path: "pkg/modified.go", Status: StatusModified}, wantSnippet: "func New"}, - {name: "renamed", entry: gitChangedEntry{Path: "pkg/renamed.go", Status: StatusRenamed}, wantSnippet: "+new"}, - {name: "added reads file head", entry: gitChangedEntry{Path: "pkg/added.go", Status: StatusAdded}, wantSnippet: "func Added"}, - {name: "untracked file head", entry: gitChangedEntry{Path: "pkg/untracked.go", Status: StatusUntracked}, wantSnippet: 
"func NewFile"}, - {name: "context error", entry: gitChangedEntry{Path: "pkg/error.go", Status: StatusModified}, wantErr: context.Canceled}, - {name: "unknown status", entry: gitChangedEntry{Path: "pkg/unknown.go", Status: ChangedFileStatus("other")}}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - snippet, err := service.changedFileSnippet(context.Background(), workdir, tt.entry) - if tt.wantErr != nil { - if !errors.Is(err, tt.wantErr) { - t.Fatalf("changedFileSnippet() err = %v, want %v", err, tt.wantErr) - } - return - } - if err != nil { - t.Fatalf("changedFileSnippet() err = %v", err) - } - if tt.wantSnippet != "" && !strings.Contains(snippet.text, tt.wantSnippet) { - t.Fatalf("snippet %q does not contain %q", snippet.text, tt.wantSnippet) - } - }) - } -} - -func TestInspectPreservesSummaryWhenSingleSnippetReadFails(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "changed.go"), "package pkg\n") - service := &Service{ - gitRunner: func(ctx context.Context, dir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - switch strings.Join(args, " ") { - case "status --porcelain=v1 -z --branch --untracked-files=normal": - return gitCommandOutput{text: nulJoin("## main", " M pkg/changed.go")}, nil - case "diff --unified=3 HEAD -- pkg/changed.go": - return gitCommandOutput{}, errors.New("diff failed") - default: - return gitCommandOutput{}, nil - } - }, - readFile: readFile, - } - - result, err := service.Inspect(context.Background(), workdir, InspectOptions{ - ChangedFilesLimit: 10, - IncludeChangedFileSnippets: true, - }) - if err != nil { - t.Fatalf("Inspect() error = %v", err) - } - if !result.Summary.InGitRepo || result.Summary.Branch != "main" { - t.Fatalf("unexpected summary: %+v", result.Summary) - } - if len(result.ChangedFiles.Files) != 1 { - t.Fatalf("unexpected changed-files context: %+v", result.ChangedFiles) - } - if 
result.ChangedFiles.Files[0].Snippet != "" { - t.Fatalf("expected failed snippet to be dropped, got %q", result.ChangedFiles.Files[0].Snippet) - } -} - -func TestRetrieveUsesResolvedTargetForSnippetGate(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, ".env"), "SECRET=1\n") - hugeContent := strings.Repeat("A", maxRepositorySnippetFileBytes+1) - mustWriteFile(t, filepath.Join(workdir, "huge.txt"), hugeContent) - - if err := os.Symlink(filepath.Join(workdir, ".env"), filepath.Join(workdir, "safe.txt")); err != nil { - t.Skipf("symlink unsupported in test environment: %v", err) - } - if err := os.Symlink(filepath.Join(workdir, "huge.txt"), filepath.Join(workdir, "safe_link.txt")); err != nil { - t.Skipf("symlink unsupported in test environment: %v", err) - } - - service := &Service{readFile: readFile} - - pathResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: "safe.txt", - }) - if err != nil { - t.Fatalf("Retrieve(path) error = %v", err) - } - if len(pathResult.Hits) != 0 { - t.Fatalf("expected safe.txt alias to be gated, got %+v", pathResult.Hits) - } - - textResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "SECRET", - }) - if err != nil { - t.Fatalf("Retrieve(text) error = %v", err) - } - if len(textResult.Hits) != 0 { - t.Fatalf("expected .env alias to be excluded from text retrieval, got %+v", textResult.Hits) - } - - largeResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: "safe_link.txt", - }) - if err != nil { - t.Fatalf("Retrieve(path large) error = %v", err) - } - if len(largeResult.Hits) != 0 { - t.Fatalf("expected symlinked large file to be gated, got %+v", largeResult.Hits) - } -} - -func TestSnippetReadersAndParsers(t *testing.T) { - t.Parallel() - - t.Run("read diff snippet fallbacks", func(t *testing.T) { - 
t.Parallel() - - if snippet, err := ((*Service)(nil)).readDiffSnippet(context.Background(), "", "a.go"); err != nil || snippet != (snippetResult{}) { - t.Fatalf("nil service readDiffSnippet = (%+v, %v)", snippet, err) - } - - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - return gitCommandOutput{}, errors.New("ignored") - }, - } - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "a.go"), "package main\n") - if _, err := service.readDiffSnippet(context.Background(), workdir, "a.go"); err == nil { - t.Fatalf("expected readDiffSnippet non-context error to bubble up") - } - - service.gitRunner = func(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - return gitCommandOutput{}, context.DeadlineExceeded - } - _, err := service.readDiffSnippet(context.Background(), workdir, "a.go") - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("readDiffSnippet() err = %v, want deadline exceeded", err) - } - }) - - t.Run("read file head snippet fallbacks", func(t *testing.T) { - t.Parallel() - - if snippet, err := ((*Service)(nil)).readFileHeadSnippet("", "a.go"); err != nil || snippet != (snippetResult{}) { - t.Fatalf("nil service readFileHeadSnippet = (%+v, %v)", snippet, err) - } - workdir := t.TempDir() - service := &Service{readFile: readFile} - _, err := service.readFileHeadSnippet(workdir, "../escape.txt") - if err == nil { - t.Fatalf("expected path escape error") - } - - service.readFile = func(path string) ([]byte, error) { - return nil, errors.New("read failed") - } - mustWriteFile(t, filepath.Join(workdir, "existing.txt"), "ok") - _, err = service.readFileHeadSnippet(workdir, "existing.txt") - if err == nil { - t.Fatalf("expected readFileHeadSnippet to return read error") - } - }) -} - -func TestChangedFilesMarksTruncatedWhenDiffOutputHitsByteCap(t *testing.T) { - t.Parallel() - - workdir := 
t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "changed.go"), "package pkg\n") - service := &Service{ - gitRunner: func(ctx context.Context, dir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - switch strings.Join(args, " ") { - case "status --porcelain=v1 -z --branch --untracked-files=normal": - return gitCommandOutput{text: nulJoin("## main", " M pkg/changed.go")}, nil - case "diff --unified=3 HEAD -- pkg/changed.go": - return gitCommandOutput{ - text: "@@ -1,1 +1,2 @@\n-" + strings.Repeat("x", maxSnippetLineRunes+32) + "\n+" + strings.Repeat("y", maxSnippetLineRunes+32), - truncated: true, - }, nil - default: - return gitCommandOutput{}, nil - } - }, - readFile: readFile, - } - - result, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ - IncludeSnippets: true, - }) - if err != nil { - t.Fatalf("ChangedFiles() error = %v", err) - } - if !result.Truncated { - t.Fatalf("expected changed-files context to mark diff byte truncation") - } - if len(result.Files) != 1 || result.Files[0].Snippet == "" { - t.Fatalf("expected snippet output, got %+v", result.Files) - } - for _, line := range strings.Split(result.Files[0].Snippet, "\n") { - if len([]rune(line)) > maxSnippetLineRunes { - t.Fatalf("expected snippet line to be capped at %d runes, got %d", maxSnippetLineRunes, len([]rune(line))) - } - } -} - -func TestGitParsingHelpers(t *testing.T) { - t.Parallel() - - branch, ahead, behind := parseBranchLine("") - if branch != "" || ahead != 0 || behind != 0 { - t.Fatalf("parseBranchLine(empty) = (%q,%d,%d)", branch, ahead, behind) - } - branch, ahead, behind = parseBranchLine("No commits yet on feature/test") - if branch != "feature/test" || ahead != 0 || behind != 0 { - t.Fatalf("parseBranchLine(no commits) = (%q,%d,%d)", branch, ahead, behind) - } - branch, _, _ = parseBranchLine("HEAD (no branch)") - if branch != "detached" { - t.Fatalf("parseBranchLine(detached) = %q", branch) - } - branch, ahead, 
behind = parseBranchLine("feature/x...origin/feature/x [ahead 2, behind 1]") - if branch != "feature/x" || ahead != 2 || behind != 1 { - t.Fatalf("parseBranchLine(tracking) = (%q,%d,%d)", branch, ahead, behind) - } - branch, ahead, behind = parseBranchLine("main [ahead nope, behind 3]") - if branch != "main" || ahead != 0 || behind != 3 { - t.Fatalf("parseBranchLine(invalid ahead value) = (%q,%d,%d)", branch, ahead, behind) - } - - tests := []struct { - records []string - ok bool - consumed int - status ChangedFileStatus - path string - oldPath string - }{ - {records: nil, ok: false, consumed: 1}, - {records: []string{"?? "}, ok: false, consumed: 1}, - {records: []string{"?? pkg/new.go"}, ok: true, consumed: 1, status: StatusUntracked, path: filepath.Clean("pkg/new.go")}, - {records: []string{"R new.go", "old.go"}, ok: true, consumed: 2, status: StatusRenamed, path: filepath.Clean("new.go"), oldPath: filepath.Clean("old.go")}, - {records: []string{"C copied.go", "source.go"}, ok: true, consumed: 2, status: StatusCopied, path: filepath.Clean("copied.go"), oldPath: filepath.Clean("source.go")}, - {records: []string{" M pkg/mod.go"}, ok: true, consumed: 1, status: StatusModified, path: filepath.Clean("pkg/mod.go")}, - {records: []string{" D pkg/deleted.go"}, ok: true, consumed: 1, status: StatusDeleted, path: filepath.Clean("pkg/deleted.go")}, - {records: []string{"XY file.txt"}, ok: false, consumed: 1}, - } - for _, tt := range tests { - got, consumed, ok := parseChangedRecord(tt.records) - if ok != tt.ok { - t.Fatalf("parseChangedRecord(%v) ok=%t, want %t", tt.records, ok, tt.ok) - } - if consumed != tt.consumed { - t.Fatalf("parseChangedRecord(%v) consumed=%d, want %d", tt.records, consumed, tt.consumed) - } - if !ok { - continue - } - if got.Status != tt.status || got.Path != tt.path || got.OldPath != tt.oldPath { - t.Fatalf("parseChangedRecord(%v) = %+v, want status=%q path=%q old=%q", tt.records, got, tt.status, tt.path, tt.oldPath) - } - } - - if 
normalizeStatus('U', 'A') != StatusConflicted || - normalizeStatus('R', ' ') != StatusRenamed || - normalizeStatus('C', ' ') != StatusCopied || - normalizeStatus('D', ' ') != StatusDeleted || - normalizeStatus('A', ' ') != StatusAdded || - normalizeStatus('M', ' ') != StatusModified || - normalizeStatus('X', 'Y') != "" { - t.Fatalf("normalizeStatus() mapping mismatch") - } -} - -func TestPathAndRetrievalHelpers(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "a.go"), "package pkg\n\nconst Name = \"Widget\"\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "b.txt"), "Widget appears twice\nWidget\n") - if err := os.MkdirAll(filepath.Join(workdir, "node_modules"), 0o755); err != nil { - t.Fatalf("MkdirAll() error = %v", err) - } - mustWriteFile(t, filepath.Join(workdir, "node_modules", "ignored.txt"), "ignored") - - t.Run("normalize retrieval query", func(t *testing.T) { - t.Parallel() - - _, _, _, err := normalizeRetrievalQuery(workdir, RetrievalQuery{Mode: RetrievalModePath, Value: " "}) - if err == nil { - t.Fatalf("expected empty query error") - } - _, _, _, err = normalizeRetrievalQuery(string([]byte{0}), RetrievalQuery{Mode: RetrievalModePath, Value: "a"}) - if err == nil { - t.Fatalf("expected invalid workdir error") - } - _, _, _, err = normalizeRetrievalQuery(workdir, RetrievalQuery{Mode: RetrievalMode("x"), Value: "a"}) - if !errors.Is(err, errInvalidMode) { - t.Fatalf("normalizeRetrievalQuery invalid mode err = %v", err) - } - _, _, _, err = normalizeRetrievalQuery(workdir, RetrievalQuery{Mode: RetrievalModePath, Value: "a", ScopeDir: ".."}) - if err == nil { - t.Fatalf("expected scope traversal error") - } - _, _, _, err = normalizeRetrievalQuery(workdir, RetrievalQuery{Mode: RetrievalModePath, Value: "a", ScopeDir: "pkg/a.go"}) - if err == nil { - t.Fatalf("expected scope is not dir error") - } - - root, scope, normalized, err := normalizeRetrievalQuery(workdir, RetrievalQuery{ - Mode: 
RetrievalModeText, - Value: " Widget ", - Limit: 999, - ContextLines: -1, - }) - if err != nil { - t.Fatalf("normalizeRetrievalQuery() err = %v", err) - } - if root == "" || scope == "" { - t.Fatalf("expected resolved root/scope") - } - if normalized.Value != "Widget" || normalized.Limit != maxRetrievalLimit || normalized.ContextLines != defaultContextLines { - t.Fatalf("unexpected normalized query: %+v", normalized) - } - }) - - t.Run("line helpers and walkers", func(t *testing.T) { - t.Parallel() - - lines := splitNonEmptyLines("a\r\n\n b \n\t\nc") - if !slices.Equal(lines, []string{"a", " b ", "c"}) { - t.Fatalf("splitNonEmptyLines() = %#v", lines) - } - if snippet := trimSnippetText("", 2); snippet != (snippetResult{}) { - t.Fatalf("expected empty snippet for empty input") - } - if snippet := trimSnippetText("a\nb\nc", 2); !snippet.truncated || snippet.lines != 2 { - t.Fatalf("trimSnippetText() = %+v, want truncated 2 lines", snippet) - } - - text, hint := snippetAroundLine("line1\nline2\nline3", 99, 1) - if hint != 3 || !strings.Contains(text, "line3") { - t.Fatalf("snippetAroundLine() = (%q,%d)", text, hint) - } - if text, hint = snippetAroundLine("", 1, 1); text != "" || hint != 1 { - t.Fatalf("snippetAroundLine(empty) = (%q,%d)", text, hint) - } - - visited := make([]string, 0, 2) - err := walkWorkspaceFiles(context.Background(), workdir, workdir, func(path string) error { - visited = append(visited, filepath.Base(path)) - return nil - }) - if err != nil { - t.Fatalf("walkWorkspaceFiles() err = %v", err) - } - if slices.Contains(visited, "ignored.txt") { - t.Fatalf("expected node_modules file to be skipped, got %v", visited) - } - err = walkWorkspaceFiles(context.Background(), workdir, filepath.Join(workdir, "missing"), func(path string) error { - return nil - }) - if err == nil { - t.Fatalf("expected walkWorkspaceFiles to return walk error for missing scope") - } - - if normalizeLimit(0, 3, 10) != 3 || normalizeLimit(11, 3, 10) != 10 || normalizeLimit(4, 
3, 10) != 4 { - t.Fatalf("normalizeLimit() mismatch") - } - if filepathSlashClean("a/b") != filepath.Clean(filepath.FromSlash("a/b")) { - t.Fatalf("filepathSlashClean() mismatch") - } - if filepathSlashClean(" spaced.go ") != filepath.Clean(filepath.FromSlash(" spaced.go ")) { - t.Fatalf("filepathSlashClean() should not trim spaces") - } - if minInt(1, 2) != 1 || minInt(3, 2) != 2 { - t.Fatalf("minInt() mismatch") - } - }) -} - -func TestRetrieveAndServiceEdgeCases(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "defs.go"), "package pkg\n\ntype Widget struct{}\n\nfunc BuildWidget() {}\nconst WidgetName = \"x\"\nvar WidgetVar = 1\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "notes.txt"), "Widget WidgetName") - - service := newTestService(runGitCommandTestRunner) - - t.Run("retrieve path guards and not exist", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - if _, err := service.retrieveByPath(ctx, workdir, RetrievalQuery{Mode: RetrievalModePath, Value: "pkg/defs.go"}); !errors.Is(err, context.Canceled) { - t.Fatalf("retrieveByPath canceled err = %v", err) - } - - result, err := service.retrieveByPath(context.Background(), workdir, RetrievalQuery{Mode: RetrievalModePath, Value: "pkg/missing.go"}) - if err != nil { - t.Fatalf("retrieveByPath missing err = %v", err) - } - if len(result.Hits) != 0 { - t.Fatalf("expected empty hits for missing file, got %+v", result) - } - }) - - t.Run("retrieve glob/text/symbol helpers", func(t *testing.T) { - t.Parallel() - - _, err := service.retrieveByGlob(context.Background(), workdir, workdir, RetrievalQuery{ - Mode: RetrievalModeGlob, - Value: "[", - Limit: 5, - }) - if err == nil { - t.Fatalf("expected invalid glob pattern error") - } - - textResult, err := service.retrieveByText(context.Background(), workdir, workdir, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "Widget", - Limit: 2, - 
ContextLines: 1, - }, false) - if err != nil || len(textResult.Hits) == 0 { - t.Fatalf("retrieveByText() = (%+v, %v), want hits", textResult, err) - } - - wordResult, err := service.retrieveByText(context.Background(), workdir, workdir, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "Widget", - Limit: 5, - ContextLines: 1, - }, true) - if err != nil || len(wordResult.Hits) == 0 { - t.Fatalf("retrieveByText wholeWord() = (%+v, %v), want hits", wordResult, err) - } - - symbolResult, err := service.retrieveBySymbol(context.Background(), workdir, workdir, RetrievalQuery{ - Mode: RetrievalModeSymbol, - Value: "BuildWidget", - Limit: 5, - ContextLines: 1, - }) - if err != nil || len(symbolResult.Hits) == 0 { - t.Fatalf("retrieveBySymbol() = (%+v, %v), want symbol hits", symbolResult, err) - } - - fallbackResult, err := service.retrieveBySymbol(context.Background(), workdir, workdir, RetrievalQuery{ - Mode: RetrievalModeSymbol, - Value: "WidgetName", - Limit: 5, - ContextLines: 1, - }) - if err != nil || len(fallbackResult.Hits) == 0 { - t.Fatalf("retrieveBySymbol fallback() = (%+v, %v), want hits", fallbackResult, err) - } - for _, hit := range fallbackResult.Hits { - if hit.Kind != string(RetrievalModeSymbol) { - t.Fatalf("expected fallback kind rewritten to symbol, got %+v", hit) - } - } - }) - - t.Run("find symbol definitions and sorting", func(t *testing.T) { - t.Parallel() - - defs := findGoSymbolDefinitions(strings.Join([]string{ - "package p", - "type Widget struct{}", - "func BuildWidget(){}", - "func (s *Svc) BuildWidget(){}", - "const WidgetName = \"x\"", - "var WidgetVar = 1", - "const (", - "WidgetInBlock = 1", - ")", - "var (", - "WidgetVarBlock = 2", - ")", - }, "\n"), "BuildWidget") - if len(defs) < 2 { - t.Fatalf("expected function + method definitions, got %v", defs) - } - if got := findGoSymbolDefinitions("package p", " "); got != nil { - t.Fatalf("expected nil for empty symbol, got %v", got) - } - - hits := []RetrievalHit{ - {Path: "b.go", 
LineHint: 3}, - {Path: "a.go", LineHint: 8}, - {Path: "a.go", LineHint: 2}, - } - sortRetrievalHits(hits) - if hits[0].Path != "a.go" || hits[0].LineHint != 2 || hits[2].Path != "b.go" { - t.Fatalf("sortRetrievalHits() unexpected order: %+v", hits) - } - }) - - t.Run("summary and changed files error branches", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - _, err := service.Summary(ctx, workdir) - if !errors.Is(err, context.Canceled) { - t.Fatalf("Summary() err = %v, want context canceled", err) - } - - serviceWithCancelledDiff := &Service{ - gitRunner: func(ctx context.Context, dir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - switch strings.Join(args, " ") { - case "status --porcelain=v1 -z --branch --untracked-files=normal": - return gitCommandOutput{text: nulJoin("## main", " M pkg/new.go")}, nil - case "diff --unified=3 HEAD -- pkg/new.go": - return gitCommandOutput{}, context.DeadlineExceeded - default: - return gitCommandOutput{}, nil - } - }, - readFile: readFile, - } - mustWriteFile(t, filepath.Join(workdir, "pkg", "new.go"), "package pkg\n\nfunc New(){}\n") - _, err = serviceWithCancelledDiff.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{IncludeSnippets: true}) - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("ChangedFiles() err = %v, want deadline exceeded", err) - } - - _, err = service.Retrieve(ctx, workdir, RetrievalQuery{Mode: RetrievalModeText, Value: "Widget"}) - if !errors.Is(err, context.Canceled) { - t.Fatalf("Retrieve() err = %v, want context canceled", err) - } - - if !isNotGitRepository("fatal: not a git repository", errors.New("x")) { - t.Fatalf("expected not-git output to be recognized") - } - if isNotGitRepository("", nil) { - t.Fatalf("expected nil error to return false") - } - if !isContextError(context.Canceled) || !isContextError(context.DeadlineExceeded) || isContextError(errors.New("x")) { - 
t.Fatalf("isContextError() mismatch") - } - }) -} - -func TestRepositoryCoverageExtraBranches(t *testing.T) { - t.Parallel() - - t.Run("runGitCommand success and failure", func(t *testing.T) { - t.Parallel() - - out, err := runGitCommand(context.Background(), t.TempDir(), gitCommandOptions{}, "--version") - if err != nil { - t.Fatalf("runGitCommand(--version) err = %v", err) - } - if !strings.Contains(strings.ToLower(out.text), "git version") { - t.Fatalf("unexpected git --version output: %q", out.text) - } - - _, err = runGitCommand(context.Background(), t.TempDir(), gitCommandOptions{}, "unknown-subcommand-for-test") - if err == nil { - t.Fatalf("expected runGitCommand invalid subcommand to fail") - } - }) - - t.Run("parse snapshot and counters", func(t *testing.T) { - t.Parallel() - - emptySnapshot := parseGitSnapshot("") - if emptySnapshot.InGitRepo || len(emptySnapshot.Entries) != 0 { - t.Fatalf("parseGitSnapshot(empty) = %+v", emptySnapshot) - } - - snapshot := parseGitSnapshot(nulJoin(" M a.go", "?? b.go")) - if !snapshot.InGitRepo || len(snapshot.Entries) != 2 { - t.Fatalf("parseGitSnapshot(without branch line) = %+v", snapshot) - } - copied := parseGitSnapshot(nulJoin("## main", "C copied.go", "source.go", "?? 
tail.go")) - if len(copied.Entries) != 2 { - t.Fatalf("expected copy snapshot entries, got %+v", copied) - } - if copied.Entries[0].Status != StatusCopied || copied.Entries[0].Path != filepath.Clean("copied.go") || copied.Entries[0].OldPath != filepath.Clean("source.go") { - t.Fatalf("expected copied entry to parse cleanly, got %+v", copied.Entries[0]) - } - if copied.Entries[1].Path != filepath.Clean("tail.go") { - t.Fatalf("expected following record to stay aligned, got %+v", copied.Entries[1]) - } - quoted := parseGitSnapshot(nulJoin( - ` M dir with space/file name.txt`, - `R dir with space/new name.txt`, - `dir with space/old name.txt`, - )) - if len(quoted.Entries) != 2 { - t.Fatalf("expected quoted-path snapshot entries, got %+v", quoted) - } - if quoted.Entries[0].Path != filepath.Clean("dir with space/file name.txt") { - t.Fatalf("expected clean path with spaces, got %+v", quoted.Entries[0]) - } - if quoted.Entries[1].Path != filepath.Clean("dir with space/new name.txt") || quoted.Entries[1].OldPath != filepath.Clean("dir with space/old name.txt") { - t.Fatalf("expected rename paths with spaces, got %+v", quoted.Entries[1]) - } - - ahead, behind := parseTrackingCounters("main [ahead 2, weird, behind 1, ahead nope]") - if ahead != 2 || behind != 1 { - t.Fatalf("parseTrackingCounters() = (%d,%d), want (2,1)", ahead, behind) - } - ahead, behind = parseTrackingCounters("main []") - if ahead != 0 || behind != 0 { - t.Fatalf("parseTrackingCounters(empty segment) = (%d,%d), want (0,0)", ahead, behind) - } - }) - - t.Run("scope and snippet boundaries", func(t *testing.T) { - t.Parallel() - - root := t.TempDir() - scope, err := resolveScopeDir(root, "") - if err != nil || scope == "" { - t.Fatalf("resolveScopeDir(empty) = (%q, %v)", scope, err) - } - _, err = resolveScopeDir(root, "missing") - if err == nil { - t.Fatalf("expected resolveScopeDir missing path error") - } - - snippet, hint := snippetAroundLine("a\nb\nc", 0, 1) - if hint != 1 || 
!strings.Contains(snippet, "a") { - t.Fatalf("snippetAroundLine(line<=0) = (%q,%d)", snippet, hint) - } - if _, err := resolveScopeDir(root, ".."); err == nil { - t.Fatalf("expected resolveScopeDir to reject traversal") - } - }) - - t.Run("walk workspace callback and symlink escape", func(t *testing.T) { - t.Parallel() - - root := t.TempDir() - mustWriteFile(t, filepath.Join(root, "a.txt"), "a") - expectedErr := errors.New("stop") - err := walkWorkspaceFiles(context.Background(), root, root, func(path string) error { - return expectedErr - }) - if !errors.Is(err, expectedErr) { - t.Fatalf("walkWorkspaceFiles(callback err) = %v", err) - } - - outsideDir := t.TempDir() - outsideFile := filepath.Join(outsideDir, "secret.txt") - if err := os.WriteFile(outsideFile, []byte("secret"), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - linkPath := filepath.Join(root, "escape.txt") - if err := os.Symlink(outsideFile, linkPath); err == nil { - err = walkWorkspaceFiles(context.Background(), root, root, func(path string) error { - return nil - }) - if err == nil { - t.Fatalf("expected symlink escape error from walkWorkspaceFiles") - } - } - - canceledCtx, cancel := context.WithCancel(context.Background()) - cancel() - err = walkWorkspaceFiles(canceledCtx, root, root, func(path string) error { - return nil - }) - if !errors.Is(err, context.Canceled) { - t.Fatalf("walkWorkspaceFiles(canceled) err = %v", err) - } - }) - - t.Run("retrieve branches and service switches", func(t *testing.T) { - t.Parallel() - - root := t.TempDir() - mustWriteFile(t, filepath.Join(root, "pkg", "defs.go"), strings.Join([]string{ - "package pkg", - "func BuildWidget(){}", - "func BuildWidget2(){}", - "func (s *Svc) BuildWidget(){}", - "const (", - "WidgetName = \"x\"", - ")", - }, "\n")) - mustWriteFile(t, filepath.Join(root, "pkg", "match.txt"), "hit\nhit\nhit") - - svc := newTestService(runGitCommandTestRunner) - canceledCtx, cancel := context.WithCancel(context.Background()) - 
cancel() - if _, err := svc.retrieveByGlob(canceledCtx, root, root, RetrievalQuery{Mode: RetrievalModeGlob, Value: "*.go", Limit: 1}); !errors.Is(err, context.Canceled) { - t.Fatalf("retrieveByGlob(canceled) err = %v", err) - } - if _, err := svc.retrieveByText(canceledCtx, root, root, RetrievalQuery{Mode: RetrievalModeText, Value: "hit", Limit: 1}, false); !errors.Is(err, context.Canceled) { - t.Fatalf("retrieveByText(canceled) err = %v", err) - } - if _, err := svc.retrieveBySymbol(canceledCtx, root, root, RetrievalQuery{Mode: RetrievalModeSymbol, Value: "BuildWidget", Limit: 1}); !errors.Is(err, context.Canceled) { - t.Fatalf("retrieveBySymbol(canceled) err = %v", err) - } - - // non-not-exist read error branch for retrieveByPath. - failingReadSvc := &Service{ - readFile: func(path string) ([]byte, error) { - return nil, fmt.Errorf("permission denied") - }, - } - _, err := failingReadSvc.retrieveByPath(context.Background(), root, RetrievalQuery{ - Mode: RetrievalModePath, - Value: "pkg/defs.go", - ContextLines: 1, - }) - if err == nil { - t.Fatalf("expected retrieveByPath non-not-exist error") - } - _, err = failingReadSvc.retrieveByGlob(context.Background(), root, root, RetrievalQuery{ - Mode: RetrievalModeGlob, - Value: "*.txt", - Limit: 5, - }) - if err != nil { - t.Fatalf("retrieveByGlob(read err ignored) err = %v", err) - } - _, err = failingReadSvc.retrieveByText(context.Background(), root, root, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "hit", - Limit: 5, - }, false) - if err != nil { - t.Fatalf("retrieveByText(read err ignored) err = %v", err) - } - _, err = failingReadSvc.retrieveBySymbol(context.Background(), root, root, RetrievalQuery{ - Mode: RetrievalModeSymbol, - Value: "BuildWidget", - Limit: 5, - }) - if err != nil { - t.Fatalf("retrieveBySymbol(read err ignored) err = %v", err) - } - - globResult, err := svc.retrieveByGlob(context.Background(), root, root, RetrievalQuery{ - Mode: RetrievalModeGlob, - Value: "pkg/*.txt", - Limit: 1, - 
ContextLines: 1, - }) - if err != nil || len(globResult.Hits) != 1 || globResult.Truncated { - t.Fatalf("retrieveByGlob(limit=1) = (%+v, %v)", globResult, err) - } - - textResult, err := svc.retrieveByText(context.Background(), root, root, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "hit", - Limit: 1, - ContextLines: 1, - }, false) - if err != nil || len(textResult.Hits) != 1 || !textResult.Truncated { - t.Fatalf("retrieveByText(limit=1) = (%+v, %v)", textResult, err) - } - - symbolResult, err := svc.retrieveBySymbol(context.Background(), root, root, RetrievalQuery{ - Mode: RetrievalModeSymbol, - Value: "BuildWidget", - Limit: 1, - ContextLines: 1, - }) - if err != nil || len(symbolResult.Hits) != 1 || !symbolResult.Truncated { - t.Fatalf("retrieveBySymbol(limit=1) = (%+v, %v)", symbolResult, err) - } - - visitedCount := 0 - limitRoot := t.TempDir() - mustWriteFile(t, filepath.Join(limitRoot, "a.txt"), "hit\nhit\n") - limitSvc := &Service{ - readFile: func(path string) ([]byte, error) { - visitedCount++ - return readFile(path) - }, - } - limitedResult, err := limitSvc.retrieveByText(context.Background(), limitRoot, limitRoot, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "hit", - Limit: 1, - ContextLines: 1, - }, false) - if err != nil { - t.Fatalf("retrieveByText(early stop) err = %v", err) - } - if len(limitedResult.Hits) != 1 || !limitedResult.Truncated { - t.Fatalf("expected one limited hit with truncation, got %+v", limitedResult) - } - if visitedCount != 1 { - t.Fatalf("expected retrieval walk to stop after first file, visited %d files", visitedCount) - } - - exactLimitRoot := t.TempDir() - mustWriteFile(t, filepath.Join(exactLimitRoot, "only.txt"), "hit\n") - exactResult, err := svc.retrieveByText(context.Background(), exactLimitRoot, exactLimitRoot, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "hit", - Limit: 1, - ContextLines: 1, - }, false) - if err != nil { - t.Fatalf("retrieveByText(exact limit) err = %v", err) - } - if 
len(exactResult.Hits) != 1 || exactResult.Truncated { - t.Fatalf("expected one exact-limit hit without truncation, got %+v", exactResult) - } - _, err = svc.retrieveByText(context.Background(), root, filepath.Join(root, "missing"), RetrievalQuery{ - Mode: RetrievalModeText, - Value: "hit", - Limit: 1, - }, true) - if err == nil { - t.Fatalf("expected retrieveByText missing scope error") - } - _, err = svc.retrieveBySymbol(context.Background(), root, filepath.Join(root, "missing"), RetrievalQuery{ - Mode: RetrievalModeSymbol, - Value: "Unknown", - Limit: 1, - }) - if err == nil { - t.Fatalf("expected retrieveBySymbol missing scope error") - } - - _, err = svc.Retrieve(context.Background(), root, RetrievalQuery{Mode: RetrievalModeGlob, Value: "*.go"}) - if err != nil { - t.Fatalf("Retrieve(glob) err = %v", err) - } - _, err = svc.Retrieve(context.Background(), root, RetrievalQuery{Mode: RetrievalModeText, Value: "BuildWidget"}) - if err != nil { - t.Fatalf("Retrieve(text) err = %v", err) - } - _, err = svc.Retrieve(context.Background(), root, RetrievalQuery{Mode: RetrievalModeSymbol, Value: "BuildWidget"}) - if err != nil { - t.Fatalf("Retrieve(symbol) err = %v", err) - } - _, err = svc.Retrieve(context.Background(), root, RetrievalQuery{Mode: RetrievalMode("invalid"), Value: "BuildWidget"}) - if !errors.Is(err, errInvalidMode) { - t.Fatalf("Retrieve(invalid mode) err = %v", err) - } - }) - - t.Run("summary representative limit and changed-files without snippets", func(t *testing.T) { - t.Parallel() - - service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { - lines := []string{"## main"} - for i := 0; i < representativeChangedFilesLimit+2; i++ { - lines = append(lines, fmt.Sprintf(" M file%d.go", i)) - } - return nulJoin(lines...), nil - }) - summary, err := service.Summary(context.Background(), t.TempDir()) - if err != nil { - t.Fatalf("Summary() err = %v", err) - } - if len(summary.RepresentativeChangedFiles) != 
representativeChangedFilesLimit { - t.Fatalf("expected representative list to be capped at %d, got %d", representativeChangedFilesLimit, len(summary.RepresentativeChangedFiles)) - } - - changed, err := service.ChangedFiles(context.Background(), t.TempDir(), ChangedFilesOptions{IncludeSnippets: false}) - if err != nil { - t.Fatalf("ChangedFiles(without snippets) err = %v", err) - } - for _, file := range changed.Files { - if file.Snippet != "" { - t.Fatalf("expected snippet empty when IncludeSnippets=false, got %q", file.Snippet) - } - } - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - if _, err := service.ChangedFiles(ctx, t.TempDir(), ChangedFilesOptions{}); !errors.Is(err, context.Canceled) { - t.Fatalf("ChangedFiles(canceled) err = %v", err) - } - - nonGitService := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { - return "fatal: not a git repository", errors.New("exit status 128") - }) - ctxResult, err := nonGitService.ChangedFiles(context.Background(), t.TempDir(), ChangedFilesOptions{}) - if err != nil { - t.Fatalf("ChangedFiles(non-git) err = %v", err) - } - if len(ctxResult.Files) != 0 || ctxResult.TotalCount != 0 || ctxResult.ReturnedCount != 0 { - t.Fatalf("expected empty changed-files for non-git dir, got %+v", ctxResult) - } - }) -} diff --git a/internal/context/repository/repository_test.go b/internal/context/repository/repository_test.go deleted file mode 100644 index 020c7544..00000000 --- a/internal/context/repository/repository_test.go +++ /dev/null @@ -1,697 +0,0 @@ -package repository - -import ( - "context" - "errors" - "os" - "path/filepath" - "strconv" - "strings" - "testing" -) - -func TestSummaryReturnsStableEmptyForNonGitDirectory(t *testing.T) { - t.Parallel() - - service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { - return "fatal: not a git repository", errors.New("exit status 128") - }) - - summary, err := 
service.Summary(context.Background(), t.TempDir()) - if err != nil { - t.Fatalf("Summary() error = %v", err) - } - if summary.InGitRepo { - t.Fatalf("expected non-git summary, got %+v", summary) - } -} - -func TestSummaryParsesBranchDirtyAheadBehindAndRepresentativeFiles(t *testing.T) { - t.Parallel() - - service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { - return nulJoin( - "## feature/repository...origin/feature/repository [ahead 2, behind 1]", - " M internal/context/source_system.go", - "R new/name.go", - "old/name.go", - "?? internal/context/repository/service.go", - ), nil - }) - - summary, err := service.Summary(context.Background(), t.TempDir()) - if err != nil { - t.Fatalf("Summary() error = %v", err) - } - if !summary.InGitRepo || !summary.Dirty { - t.Fatalf("expected git repo summary, got %+v", summary) - } - if summary.Branch != "feature/repository" { - t.Fatalf("expected branch parsed, got %q", summary.Branch) - } - if summary.Ahead != 2 || summary.Behind != 1 { - t.Fatalf("expected ahead=2 behind=1, got %+v", summary) - } - if summary.ChangedFileCount != 3 { - t.Fatalf("expected 3 changed files, got %d", summary.ChangedFileCount) - } - expected := []string{ - filepath.Clean("internal/context/source_system.go"), - filepath.Clean("new/name.go"), - filepath.Clean("internal/context/repository/service.go"), - } - for index, path := range expected { - if summary.RepresentativeChangedFiles[index] != path { - t.Fatalf("expected representative path %q, got %q", path, summary.RepresentativeChangedFiles[index]) - } - } -} - -func TestInspectSharesSnapshotForSummaryAndChangedFiles(t *testing.T) { - t.Parallel() - - calls := 0 - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - calls++ - if strings.Join(args, " ") != "status --porcelain=v1 -z --branch --untracked-files=normal" { - t.Fatalf("unexpected git args: %v", 
args) - } - return gitCommandOutput{text: nulJoin("## main", " M internal/runtime/run.go")}, nil - }, - readFile: readFile, - } - - result, err := service.Inspect(context.Background(), t.TempDir(), InspectOptions{ - ChangedFilesLimit: 10, - }) - if err != nil { - t.Fatalf("Inspect() error = %v", err) - } - if calls != 1 { - t.Fatalf("expected Inspect() to load a single snapshot, got %d calls", calls) - } - if !result.Summary.InGitRepo || result.Summary.Branch != "main" { - t.Fatalf("unexpected summary: %+v", result.Summary) - } - if result.ChangedFiles.TotalCount != 1 || len(result.ChangedFiles.Files) != 1 { - t.Fatalf("unexpected changed-files context: %+v", result.ChangedFiles) - } -} - -func TestChangedFilesRespectsStatusNormalizationAndSnippetRules(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - if err := os.MkdirAll(filepath.Join(workdir, "pkg"), 0o755); err != nil { - t.Fatalf("MkdirAll() error = %v", err) - } - if err := os.WriteFile(filepath.Join(workdir, "pkg", "changed.go"), []byte("package pkg\n\nfunc Changed() {}\n"), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - if err := os.WriteFile(filepath.Join(workdir, "pkg", "new.go"), []byte("package pkg\n\nfunc Added() {}\n"), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - if err := os.WriteFile(filepath.Join(workdir, "pkg", "untracked.go"), []byte("package pkg\n\nfunc Untracked() {}\n"), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - if err := os.WriteFile(filepath.Join(workdir, "pkg", "renamed.go"), []byte("package pkg\n\nfunc Renamed() {}\n"), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - - service := newTestService(func(ctx context.Context, dir string, args ...string) (string, error) { - switch strings.Join(args, " ") { - case "status --porcelain=v1 -z --branch --untracked-files=normal": - return nulJoin( - "## main...origin/main [ahead 1]", - " M pkg/changed.go", - "A pkg/new.go", - "?? 
pkg/untracked.go", - "D pkg/deleted.go", - "R pkg/renamed.go", - "pkg/old.go", - "C pkg/copied.go", - "pkg/source.go", - "UU pkg/conflicted.go", - ), nil - case "diff --unified=3 HEAD -- pkg/changed.go": - return "@@ -1,1 +1,1 @@\n-func Old() {}\n+func Changed() {}\n", nil - case "diff --unified=3 HEAD -- pkg/new.go": - return "@@ -0,0 +1,3 @@\n+package pkg\n+\n+func Added() {}\n", nil - case "diff --unified=3 HEAD -- pkg/renamed.go": - return "", nil - default: - return "", nil - } - }) - - ctx, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ - IncludeSnippets: true, - }) - if err != nil { - t.Fatalf("ChangedFiles() error = %v", err) - } - if ctx.TotalCount != 7 || ctx.ReturnedCount != 7 { - t.Fatalf("unexpected count summary: %+v", ctx) - } - assertChangedFile(t, ctx.Files[0], filepath.Clean("pkg/changed.go"), "", StatusModified, "Changed") - assertChangedFile(t, ctx.Files[1], filepath.Clean("pkg/new.go"), "", StatusAdded, "Added") - assertChangedFile(t, ctx.Files[2], filepath.Clean("pkg/untracked.go"), "", StatusUntracked, "Untracked") - assertChangedFile(t, ctx.Files[3], filepath.Clean("pkg/deleted.go"), "", StatusDeleted, "") - assertChangedFile(t, ctx.Files[4], filepath.Clean("pkg/renamed.go"), filepath.Clean("pkg/old.go"), StatusRenamed, "") - assertChangedFile(t, ctx.Files[5], filepath.Clean("pkg/copied.go"), filepath.Clean("pkg/source.go"), StatusCopied, "") - assertChangedFile(t, ctx.Files[6], filepath.Clean("pkg/conflicted.go"), "", StatusConflicted, "") -} - -func TestChangedFilesAppliesLimitAndTruncation(t *testing.T) { - t.Parallel() - - service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { - lines := []string{"## main"} - for i := 0; i < 60; i++ { - lines = append(lines, " M file"+strconv.Itoa(i)+".go") - } - return nulJoin(lines...), nil - }) - - result, err := service.ChangedFiles(context.Background(), t.TempDir(), ChangedFilesOptions{}) - if err != nil { - 
t.Fatalf("ChangedFiles() error = %v", err) - } - if !result.Truncated { - t.Fatalf("expected truncation for oversized changed files list") - } - if result.ReturnedCount != defaultChangedFilesLimit { - t.Fatalf("expected default limit %d, got %d", defaultChangedFilesLimit, result.ReturnedCount) - } -} - -func TestChangedFilesMarksTruncatedWhenSingleSnippetExceedsLineLimit(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "long.go"), "package pkg\n") - service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { - switch strings.Join(args, " ") { - case "status --porcelain=v1 -z --branch --untracked-files=normal": - return nulJoin("## main", " M pkg/long.go"), nil - case "diff --unified=3 HEAD -- pkg/long.go": - lines := []string{"@@ -1,1 +1,25 @@"} - for i := 0; i < 25; i++ { - lines = append(lines, "+line "+strconv.Itoa(i)) - } - return strings.Join(lines, "\n"), nil - default: - return "", nil - } - }) - - result, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ - IncludeSnippets: true, - }) - if err != nil { - t.Fatalf("ChangedFiles() error = %v", err) - } - if !result.Truncated { - t.Fatalf("expected snippet truncation to set Truncated") - } - if got := len(splitNonEmptyLines(result.Files[0].Snippet)); got != maxChangedSnippetLinesPerFile { - t.Fatalf("expected snippet to be trimmed to %d lines, got %d", maxChangedSnippetLinesPerFile, got) - } -} - -func TestChangedFilesMarksTruncatedWhenTotalSnippetBudgetExceeded(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - lines := make([]string, 0, maxChangedSnippetLinesPerFile+2) - lines = append(lines, "package pkg") - for i := 0; i < maxChangedSnippetLinesPerFile+1; i++ { - lines = append(lines, "line "+strconv.Itoa(i)) - } - content := strings.Join(lines, "\n") - - statusLines := []string{"## main"} - for i := 0; i < 11; i++ { - fileName := filepath.Join("pkg", 
"file"+strconv.Itoa(i)+".txt") - mustWriteFile(t, filepath.Join(workdir, fileName), content) - statusLines = append(statusLines, "?? "+filepath.ToSlash(fileName)) - } - - service := newTestService(func(ctx context.Context, dir string, args ...string) (string, error) { - if strings.Join(args, " ") == "status --porcelain=v1 -z --branch --untracked-files=normal" { - return nulJoin(statusLines...), nil - } - return "", nil - }) - - result, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ - IncludeSnippets: true, - }) - if err != nil { - t.Fatalf("ChangedFiles() error = %v", err) - } - if !result.Truncated { - t.Fatalf("expected total snippet budget truncation to set Truncated") - } - last := result.Files[len(result.Files)-1] - if last.Snippet != "" { - t.Fatalf("expected last snippet to be dropped after total budget is exhausted, got %q", last.Snippet) - } -} - -func TestChangedFilesBlocksSensitiveLargeAndBinarySnippets(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, ".env"), "API_KEY=secret\n") - mustWriteFile(t, filepath.Join(workdir, ".envrc"), "export API_KEY=secret\n") - mustWriteFile(t, filepath.Join(workdir, ".npmrc"), "token=secret\n") - mustWriteFile(t, filepath.Join(workdir, ".aws", "credentials"), "[default]\naws_access_key_id=secret\n") - mustWriteFile(t, filepath.Join(workdir, ".ssh", "id_rsa"), "PRIVATE KEY\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "cert.pem"), "-----BEGIN PRIVATE KEY-----\nsecret\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "issuer.p8"), "PRIVATE KEY\n") - mustWriteFile(t, filepath.Join(workdir, "config", "secrets.yml"), "token: secret\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "secrets.txt"), "secret dump\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "private-secrets.md"), "private material\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "token_dump.log"), "token dump\n") - mustWriteFile(t, filepath.Join(workdir, 
"pkg", "credentials.json"), "{\"token\":\"secret\"}\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "secrets"), "secret dump\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "bin.dat"), string([]byte{0x00, 0x01, 0x02})) - mustWriteFile(t, filepath.Join(workdir, "pkg", "large.txt"), strings.Repeat("x", maxRepositorySnippetFileBytes+1)) - - service := newTestService(func(ctx context.Context, dir string, args ...string) (string, error) { - if strings.Join(args, " ") == "status --porcelain=v1 -z --branch --untracked-files=normal" { - return nulJoin( - "## main", - "?? .env", - "?? .envrc", - "?? .npmrc", - "?? .aws/credentials", - "?? .ssh/id_rsa", - "?? pkg/cert.pem", - "?? pkg/issuer.p8", - "?? config/secrets.yml", - "?? pkg/secrets.txt", - "?? pkg/private-secrets.md", - "?? pkg/token_dump.log", - "?? pkg/credentials.json", - "?? pkg/secrets", - "?? pkg/bin.dat", - "?? pkg/large.txt", - ), nil - } - return "", nil - }) - - result, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ - IncludeSnippets: true, - }) - if err != nil { - t.Fatalf("ChangedFiles() error = %v", err) - } - for _, file := range result.Files { - if file.Snippet != "" { - t.Fatalf("expected filtered file to have empty snippet, got %+v", file) - } - } -} - -func TestChangedFilesBlocksModifiedSensitiveDiffSnippet(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, ".env"), "API_KEY=secret\n") - mustWriteFile(t, filepath.Join(workdir, "config", "secrets.yaml"), "token: secret\n") - - service := newTestService(func(ctx context.Context, dir string, args ...string) (string, error) { - switch strings.Join(args, " ") { - case "status --porcelain=v1 -z --branch --untracked-files=normal": - return nulJoin("## main", " M .env", " M config/secrets.yaml"), nil - case "diff --unified=3 HEAD -- .env": - return "@@ -1,1 +1,1 @@\n-API_KEY=old\n+API_KEY=new\n", nil - case "diff --unified=3 HEAD -- config/secrets.yaml": - return "@@ 
-1,1 +1,1 @@\n-token: old\n+token: new\n", nil - default: - return "", nil - } - }) - - result, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ - IncludeSnippets: true, - }) - if err != nil { - t.Fatalf("ChangedFiles() error = %v", err) - } - if len(result.Files) != 2 { - t.Fatalf("expected two changed files, got %+v", result.Files) - } - for _, file := range result.Files { - if file.Snippet != "" { - t.Fatalf("expected sensitive modified file to have empty snippet, got %+v", file) - } - } -} - -func TestChangedFilesRespectsSnippetFileCountLimit(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "one.go"), "package pkg\n\nfunc One() {}\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "two.go"), "package pkg\n\nfunc Two() {}\n") - - service := newTestService(func(ctx context.Context, dir string, args ...string) (string, error) { - switch strings.Join(args, " ") { - case "status --porcelain=v1 -z --branch --untracked-files=normal": - return nulJoin("## main", "?? pkg/one.go", "?? 
pkg/two.go"), nil - default: - return "", nil - } - }) - - allowed, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ - IncludeSnippets: true, - SnippetFileCountLimit: 2, - }) - if err != nil { - t.Fatalf("ChangedFiles() allow error = %v", err) - } - if allowed.Files[0].Snippet == "" || allowed.Files[1].Snippet == "" { - t.Fatalf("expected snippets when total count does not exceed limit, got %+v", allowed.Files) - } - - blocked, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ - IncludeSnippets: true, - SnippetFileCountLimit: 1, - }) - if err != nil { - t.Fatalf("ChangedFiles() block error = %v", err) - } - if blocked.Files[0].Snippet != "" || blocked.Files[1].Snippet != "" { - t.Fatalf("expected snippets to be suppressed after count limit, got %+v", blocked.Files) - } -} - -func TestSummaryReturnsErrorForUnexpectedGitFailure(t *testing.T) { - t.Parallel() - - service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { - return "fatal: permission denied", errors.New("exit status 128") - }) - - _, err := service.Summary(context.Background(), t.TempDir()) - if err == nil { - t.Fatalf("expected unexpected git failure to be returned") - } -} - -func TestRetrieveSupportsPathGlobTextAndSymbol(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "target.go"), "package pkg\n\ntype Widget struct{}\n\nfunc BuildWidget() Widget {\n\treturn Widget{}\n}\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "notes.txt"), "Widget appears here too\n") - - service := NewService() - - pathResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: "pkg/target.go", - }) - if err != nil { - t.Fatalf("Retrieve(path) error = %v", err) - } - if len(pathResult.Hits) != 1 || pathResult.Hits[0].Kind != string(RetrievalModePath) || pathResult.Truncated { - t.Fatalf("unexpected path 
result: %+v", pathResult) - } - - globResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModeGlob, - Value: "*.go", - }) - if err != nil { - t.Fatalf("Retrieve(glob) error = %v", err) - } - if len(globResult.Hits) == 0 { - t.Fatalf("expected glob hits") - } - - textResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "Widget", - }) - if err != nil { - t.Fatalf("Retrieve(text) error = %v", err) - } - if len(textResult.Hits) < 2 { - t.Fatalf("expected text hits across files, got %+v", textResult) - } - - symbolResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModeSymbol, - Value: "BuildWidget", - }) - if err != nil { - t.Fatalf("Retrieve(symbol) error = %v", err) - } - if len(symbolResult.Hits) != 1 || symbolResult.Hits[0].LineHint <= 0 { - t.Fatalf("unexpected symbol hits: %+v", symbolResult) - } -} - -func TestRetrieveRejectsPathEscapeAndSymlinkEscape(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - outsideDir := t.TempDir() - outsideFile := filepath.Join(outsideDir, "secret.txt") - if err := os.WriteFile(outsideFile, []byte("secret"), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } - - if err := os.MkdirAll(filepath.Join(workdir, "pkg"), 0o755); err != nil { - t.Fatalf("MkdirAll() error = %v", err) - } - linkPath := filepath.Join(workdir, "pkg", "outside.txt") - if err := os.Symlink(outsideFile, linkPath); err != nil { - t.Skipf("symlink not available: %v", err) - } - - service := NewService() - - _, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: "..\\outside.txt", - }) - if err == nil { - t.Fatalf("expected path traversal to be rejected") - } - - _, err = service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: "pkg/outside.txt", - }) - if err == nil { - t.Fatalf("expected symlink 
escape to be rejected") - } -} - -func TestRetrieveSymbolFallsBackToWholeWordTextSearch(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, "pkg", "notes.txt"), "searchWidget searchWidget\n") - - service := NewService() - result, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModeSymbol, - Value: "searchWidget", - }) - if err != nil { - t.Fatalf("Retrieve(symbol fallback) error = %v", err) - } - if len(result.Hits) != 1 { - t.Fatalf("expected fallback whole-word hit, got %+v", result) - } -} - -func TestRetrieveSkipsSensitiveLargeAndBinaryFiles(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustWriteFile(t, filepath.Join(workdir, ".env"), "API_KEY=secret\n") - mustWriteFile(t, filepath.Join(workdir, ".envrc"), "export TOKEN=secret\n") - mustWriteFile(t, filepath.Join(workdir, ".npmrc"), "token=secret\n") - mustWriteFile(t, filepath.Join(workdir, ".aws", "credentials"), "[default]\naws_access_key_id=secret\n") - mustWriteFile(t, filepath.Join(workdir, "config", "secrets.yml"), "token: secret\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "issuer.p8"), "PRIVATE KEY\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "notes.key"), "private") - mustWriteFile(t, filepath.Join(workdir, "pkg", "secrets.txt"), "secret dump\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "private-secrets.md"), "private material\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "token_dump.log"), "token dump\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "credentials.json"), "{\"token\":\"secret\"}\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "secrets"), "secret dump\n") - mustWriteFile(t, filepath.Join(workdir, "pkg", "bin.dat"), string([]byte{0x00, 0x01, 0x02, 0x03})) - mustWriteFile(t, filepath.Join(workdir, "pkg", "target.txt"), "match line\n") - - largeContent := strings.Repeat("x", maxRepositorySnippetFileBytes+1) - mustWriteFile(t, filepath.Join(workdir, 
"pkg", "large.txt"), largeContent) - - service := NewService() - - pathResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: ".env", - }) - if err != nil { - t.Fatalf("Retrieve(path sensitive) error = %v", err) - } - if len(pathResult.Hits) != 0 { - t.Fatalf("expected sensitive path retrieval to be filtered, got %+v", pathResult) - } - pathResult, err = service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: ".npmrc", - }) - if err != nil { - t.Fatalf("Retrieve(path npmrc) error = %v", err) - } - if len(pathResult.Hits) != 0 { - t.Fatalf("expected .npmrc retrieval to be filtered, got %+v", pathResult) - } - pathResult, err = service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: ".aws/credentials", - }) - if err != nil { - t.Fatalf("Retrieve(path aws credentials) error = %v", err) - } - if len(pathResult.Hits) != 0 { - t.Fatalf("expected aws credentials retrieval to be filtered, got %+v", pathResult) - } - pathResult, err = service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: ".envrc", - }) - if err != nil { - t.Fatalf("Retrieve(path envrc) error = %v", err) - } - if len(pathResult.Hits) != 0 { - t.Fatalf("expected .envrc retrieval to be filtered, got %+v", pathResult) - } - pathResult, err = service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: "config/secrets.yml", - }) - if err != nil { - t.Fatalf("Retrieve(path secrets) error = %v", err) - } - if len(pathResult.Hits) != 0 { - t.Fatalf("expected secrets.yml retrieval to be filtered, got %+v", pathResult) - } - pathResult, err = service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: "pkg/issuer.p8", - }) - if err != nil { - t.Fatalf("Retrieve(path p8) error = %v", err) - } - if len(pathResult.Hits) != 0 { - t.Fatalf("expected .p8 
retrieval to be filtered, got %+v", pathResult) - } - for _, blocked := range []string{ - "pkg/secrets.txt", - "pkg/private-secrets.md", - "pkg/token_dump.log", - "pkg/credentials.json", - "pkg/secrets", - } { - pathResult, err = service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModePath, - Value: blocked, - }) - if err != nil { - t.Fatalf("Retrieve(path %s) error = %v", blocked, err) - } - if len(pathResult.Hits) != 0 { - t.Fatalf("expected %s retrieval to be filtered, got %+v", blocked, pathResult) - } - } - - textResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModeText, - Value: "match", - Limit: 10, - }) - if err != nil { - t.Fatalf("Retrieve(text) error = %v", err) - } - if len(textResult.Hits) != 1 || textResult.Hits[0].Path != filepath.Clean("pkg/target.txt") { - t.Fatalf("expected only safe text hit, got %+v", textResult) - } - - globResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ - Mode: RetrievalModeGlob, - Value: "pkg/*", - Limit: 10, - }) - if err != nil { - t.Fatalf("Retrieve(glob) error = %v", err) - } - for _, hit := range globResult.Hits { - if hit.Path == filepath.Clean("pkg/large.txt") || - hit.Path == filepath.Clean("pkg/notes.key") || - hit.Path == filepath.Clean("pkg/bin.dat") || - hit.Path == filepath.Clean("pkg/issuer.p8") { - t.Fatalf("expected filtered file to be excluded, got %+v", globResult) - } - } -} - -func assertChangedFile(t *testing.T, file ChangedFile, path string, oldPath string, status ChangedFileStatus, snippetContains string) { - t.Helper() - if file.Path != path || file.OldPath != oldPath || file.Status != status { - t.Fatalf("unexpected changed file: %+v", file) - } - if snippetContains == "" { - if file.Snippet != "" { - t.Fatalf("expected empty snippet, got %q", file.Snippet) - } - return - } - if !strings.Contains(file.Snippet, snippetContains) { - t.Fatalf("expected snippet to contain %q, got %q", 
snippetContains, file.Snippet) - } -} - -func mustWriteFile(t *testing.T, path string, content string) { - t.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - t.Fatalf("MkdirAll() error = %v", err) - } - if err := os.WriteFile(path, []byte(content), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) - } -} - -func newTestService(gitRunner func(ctx context.Context, workdir string, args ...string) (string, error)) *Service { - return &Service{ - gitRunner: func(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { - output, err := gitRunner(ctx, workdir, args...) - return gitCommandOutput{text: output}, err - }, - readFile: readFile, - } -} - -func runGitCommandTestRunner(ctx context.Context, workdir string, args ...string) (string, error) { - output, err := runGitCommand(ctx, workdir, gitCommandOptions{}, args...) - return output.text, err -} - -func nulJoin(records ...string) string { - if len(records) == 0 { - return "" - } - return strings.Join(records, "\x00") + "\x00" -} diff --git a/internal/context/source_repository_test.go b/internal/context/source_repository_test.go index edb93b85..a0f31690 100644 --- a/internal/context/source_repository_test.go +++ b/internal/context/source_repository_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "neo-code/internal/context/repository" + "neo-code/internal/repository" ) func TestRepositoryContextSourceSkipsEmptyRepositoryContext(t *testing.T) { diff --git a/internal/context/types.go b/internal/context/types.go index 244060a3..15ed99ea 100644 --- a/internal/context/types.go +++ b/internal/context/types.go @@ -3,7 +3,7 @@ package context import ( "context" - "neo-code/internal/context/repository" + "neo-code/internal/repository" providertypes "neo-code/internal/provider/types" agentsession "neo-code/internal/session" "neo-code/internal/skills" diff --git a/internal/gateway/validate.go b/internal/gateway/validate.go index 
901bd65e..581b8541 100644 --- a/internal/gateway/validate.go +++ b/internal/gateway/validate.go @@ -68,7 +68,9 @@ func validateRequestFrame(frame MessageFrame) *FrameError { FrameActionListSessionTodos, FrameActionGetRuntimeSnapshot, FrameActionDeleteSession, - FrameActionGetSessionModel: + FrameActionGetSessionModel, + FrameActionListCheckpoints, + FrameActionUndoRestore: if strings.TrimSpace(frame.SessionID) == "" { return NewMissingRequiredFieldError("session_id") } @@ -86,6 +88,15 @@ func validateRequestFrame(frame MessageFrame) *FrameError { return nil case FrameActionResolvePermission: return validateResolvePermissionFrame(frame) + case FrameActionRestoreCheckpoint, + FrameActionCheckpointDiff: + if frame.Payload == nil { + return NewMissingRequiredFieldError("payload") + } + if strings.TrimSpace(frame.SessionID) == "" { + return NewMissingRequiredFieldError("session_id") + } + return nil default: return NewFrameError(ErrorCodeInvalidAction, "invalid action") } @@ -455,6 +466,10 @@ func isValidFrameAction(action FrameAction) bool { FrameActionUpsertMCPServer, FrameActionSetMCPServerEnabled, FrameActionDeleteMCPServer, + FrameActionListCheckpoints, + FrameActionRestoreCheckpoint, + FrameActionUndoRestore, + FrameActionCheckpointDiff, FrameActionWorkspaceList, FrameActionWorkspaceCreate, FrameActionWorkspaceSwitch, diff --git a/internal/promptasset/templates/core/tool_usage.md b/internal/promptasset/templates/core/tool_usage.md index 410c4ea2..c720174d 100644 --- a/internal/promptasset/templates/core/tool_usage.md +++ b/internal/promptasset/templates/core/tool_usage.md @@ -9,6 +9,19 @@ - Use `filesystem_grep` to locate symbols, strings, and relevant code paths efficiently. - Read tool results carefully before acting. Treat `status`, `ok`, `tool_call_id`, `truncated`, `meta.*`, exit codes, and `content` as the authoritative model-visible outcome of that call. +## Repository exploration +When exploring the codebase, Git state, or current changes: +1. 
Use `git_summary` to understand repository state (branch, dirty, ahead/behind). +2. Use `git_changed_files` to list modified/added/deleted files without snippets. +3. Use `git_changed_snippets` when you need to see actual diff content of changes. +4. Use `codebase_search_symbol` to find symbol definitions (returns path, line hint, kind, and signature only). +5. Use `codebase_search_text` to find text matches across files (returns path, line hint, and match count only). +6. Use `codebase_read` to read actual file content when you need implementation details. + +Important: `codebase_search_symbol` and `codebase_search_text` do NOT return code bodies. They only return pointers/locations. You must call `codebase_read` to see the actual implementation. + +For general file operations outside of codebase exploration, use `filesystem_*` tools as usual. + ## Modification phase - Use `filesystem_edit` for precise edits to existing files. - Use `filesystem_write_file` only for new files or full rewrites. @@ -66,7 +79,7 @@ - Whenever a `filesystem_*` tool can express the operation, use it instead of `bash`. The runtime tracks `filesystem_*` operations precisely; `bash` mutations are tracked only via best-effort heuristics + workdir scanning, so undoing them is less reliable. - When using `bash`, avoid interactive or blocking commands and pass non-interactive flags when they are available. - Stay within the current workspace unless the user clearly asks for something else. -- Use Git through `bash` with this order: inspect (`git status`/`git diff`/`git log`), then mutate, then verify (`git status`/`git diff`), then summarize. +- Use Git through dedicated `git_*` tools (`git_summary`, `git_changed_files`, `git_changed_snippets`) for inspection; use `bash` only for Git mutations (commit, push, etc.) or when the dedicated tools do not cover the need. 
- Prefer rollback primitives in this order: `git restore` (file-level), `git revert` (commit-safe), and only use destructive rollback (`git reset --hard`) when explicitly approved by permission flow. ## Permission and decision flow diff --git a/internal/checkpoint/fingerprint.go b/internal/repository/fingerprint.go similarity index 99% rename from internal/checkpoint/fingerprint.go rename to internal/repository/fingerprint.go index 1d335026..af9ef719 100644 --- a/internal/checkpoint/fingerprint.go +++ b/internal/repository/fingerprint.go @@ -1,4 +1,4 @@ -package checkpoint +package repository import ( "context" diff --git a/internal/checkpoint/fingerprint_test.go b/internal/repository/fingerprint_test.go similarity index 99% rename from internal/checkpoint/fingerprint_test.go rename to internal/repository/fingerprint_test.go index 7170c358..cfcc5bb7 100644 --- a/internal/checkpoint/fingerprint_test.go +++ b/internal/repository/fingerprint_test.go @@ -1,4 +1,4 @@ -package checkpoint +package repository import ( "context" diff --git a/internal/context/repository/git.go b/internal/repository/git.go similarity index 92% rename from internal/context/repository/git.go rename to internal/repository/git.go index a51c4cc5..2523b77b 100644 --- a/internal/context/repository/git.go +++ b/internal/repository/git.go @@ -23,16 +23,16 @@ const ( maxChangedDiffBytes = 64 * 1024 ) -type gitCommandOptions struct { +type GitCommandOptions struct { MaxOutputBytes int } -type gitCommandOutput struct { - text string - truncated bool +type GitCommandOutput struct { + Text string + Truncated bool } -type gitCommandRunner func(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) +type GitCommandRunner func(ctx context.Context, workdir string, opts GitCommandOptions, args ...string) (GitCommandOutput, error) type gitSnapshot struct { InGitRepo bool @@ -57,18 +57,18 @@ func (s *Service) loadGitSnapshot(ctx context.Context, workdir string) 
(gitSnaps return gitSnapshot{}, nil } - output, err := s.gitRunner(ctx, workdir, gitCommandOptions{}, "status", "--porcelain=v1", "-z", "--branch", "--untracked-files=normal") + output, err := s.gitRunner(ctx, workdir, GitCommandOptions{}, "status", "--porcelain=v1", "-z", "--branch", "--untracked-files=normal") if err != nil { if isContextError(err) { return gitSnapshot{}, err } - if isNotGitRepository(output.text, err) || isAmbiguousGitStatusOutsideRepo(workdir, output.text, err) { + if isNotGitRepository(output.Text, err) || isAmbiguousGitStatusOutsideRepo(workdir, output.Text, err) { return gitSnapshot{}, nil } return gitSnapshot{}, err } - return parseGitSnapshot(output.text), nil + return parseGitSnapshot(output.Text), nil } // changedFileSnippet 按固定语义为单个变更条目生成受限片段。 @@ -99,15 +99,15 @@ func (s *Service) readDiffSnippet(ctx context.Context, workdir string, path stri if !allowed { return snippetResult{}, nil } - output, err := s.gitRunner(ctx, workdir, gitCommandOptions{MaxOutputBytes: maxChangedDiffBytes}, "diff", "--unified=3", "HEAD", "--", filepath.ToSlash(path)) + output, err := s.gitRunner(ctx, workdir, GitCommandOptions{MaxOutputBytes: maxChangedDiffBytes}, "diff", "--unified=3", "HEAD", "--", filepath.ToSlash(path)) if err != nil { if isContextError(err) { return snippetResult{}, err } return snippetResult{}, err } - snippet := trimSnippetText(output.text, maxChangedSnippetLinesPerFile) - if output.truncated { + snippet := trimSnippetText(output.Text, maxChangedSnippetLinesPerFile) + if output.Truncated { snippet.truncated = true } return snippet, nil @@ -291,7 +291,7 @@ func normalizeStatus(x byte, y byte) ChangedFileStatus { } // runGitCommand 统一执行 git 子命令,并在超时后主动取消。 -func runGitCommand(ctx context.Context, workdir string, opts gitCommandOptions, args ...string) (gitCommandOutput, error) { +func runGitCommand(ctx context.Context, workdir string, opts GitCommandOptions, args ...string) (GitCommandOutput, error) { timeoutCtx, cancel := 
context.WithTimeout(ctx, gitCommandTimeout) defer cancel() @@ -300,7 +300,7 @@ func runGitCommand(ctx context.Context, workdir string, opts gitCommandOptions, command.Stdout = buffer command.Stderr = io.MultiWriter(buffer) err := command.Run() - return gitCommandOutput{text: buffer.String(), truncated: buffer.truncated}, err + return GitCommandOutput{Text: buffer.String(), Truncated: buffer.truncated}, err } type gitOutputBuffer struct { diff --git a/internal/context/repository/path.go b/internal/repository/path.go similarity index 94% rename from internal/context/repository/path.go rename to internal/repository/path.go index 92cc01a6..b88d2fe3 100644 --- a/internal/context/repository/path.go +++ b/internal/repository/path.go @@ -14,9 +14,16 @@ import ( var errInvalidMode = errors.New("repository: invalid retrieval mode") -type fileReader func(path string) ([]byte, error) - -const maxSnippetLineRunes = 512 +type FileReader func(path string) ([]byte, error) + +const ( + defaultRetrievalLimit = 20 + maxRetrievalLimit = 50 + defaultContextLines = 3 + maxContextLines = 8 + maxSnippetLines = 20 + maxSnippetLineRunes = 512 +) // normalizeRetrievalQuery 统一校验检索请求并补齐默认值。 func normalizeRetrievalQuery(workdir string, query RetrievalQuery) (string, string, RetrievalQuery, error) { @@ -45,11 +52,6 @@ func normalizeRetrievalQuery(workdir string, query RetrievalQuery) (string, stri return root, scope, normalized, nil } -// resolveWorkspacePath 将工作区内的相对路径解析为绝对路径并校验边界。 -func resolveWorkspacePath(workdir string, relativePath string) (string, string, error) { - return security.ResolveWorkspacePath(workdir, relativePath) -} - // resolveScopeDir 解析检索范围目录,空值时返回整个工作区根。 func resolveScopeDir(root string, scopeDir string) (string, error) { _, target, err := security.ResolveWorkspacePath(root, scopeDir) @@ -215,3 +217,7 @@ func minInt(a int, b int) int { } return b } + +func readFile(path string) ([]byte, error) { + return os.ReadFile(path) +} diff --git 
a/internal/repository/repository_coverage_test.go b/internal/repository/repository_coverage_test.go new file mode 100644 index 00000000..3f6e5307 --- /dev/null +++ b/internal/repository/repository_coverage_test.go @@ -0,0 +1,556 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + "testing" +) + +func TestRepositoryServiceSummaryChangedFilesAndRetrieve(t *testing.T) { + t.Parallel() + + t.Run("summary non git and parsed git snapshot", func(t *testing.T) { + t.Parallel() + + service := newRepositoryTestService(func(ctx context.Context, workdir string, args ...string) (GitCommandOutput, error) { + return GitCommandOutput{Text: "fatal: not a git repository"}, errors.New("exit status 128") + }) + summary, err := service.Summary(context.Background(), t.TempDir()) + if err != nil { + t.Fatalf("Summary() error = %v", err) + } + if summary.InGitRepo { + t.Fatalf("expected non-git summary, got %+v", summary) + } + + service = newRepositoryTestService(func(ctx context.Context, workdir string, args ...string) (GitCommandOutput, error) { + return GitCommandOutput{Text: nulJoin( + "## feature/repository...origin/feature/repository [ahead 2, behind 1]", + " M pkg/changed.go", + "R pkg/new.go", + "pkg/old.go", + "?? 
pkg/untracked.go", + )}, nil + }) + summary, err = service.Summary(context.Background(), t.TempDir()) + if err != nil { + t.Fatalf("Summary() error = %v", err) + } + if !summary.InGitRepo || !summary.Dirty { + t.Fatalf("expected git repo summary, got %+v", summary) + } + if summary.Branch != "feature/repository" || summary.Ahead != 2 || summary.Behind != 1 { + t.Fatalf("unexpected summary counters: %+v", summary) + } + if summary.ChangedFileCount != 3 { + t.Fatalf("expected 3 changed files, got %d", summary.ChangedFileCount) + } + }) + + t.Run("inspect reuses snapshot and changed files respect snippet rules", func(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "changed.go"), "package pkg\n\nfunc Changed() {}\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "new.go"), "package pkg\n\nfunc Added() {}\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "untracked.go"), "package pkg\n\nfunc Untracked() {}\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "renamed.go"), "package pkg\n\nfunc Renamed() {}\n") + + calls := 0 + service := newRepositoryTestService(func(ctx context.Context, dir string, args ...string) (GitCommandOutput, error) { + calls++ + switch strings.Join(args, " ") { + case "status --porcelain=v1 -z --branch --untracked-files=normal": + return GitCommandOutput{Text: nulJoin( + "## main...origin/main [ahead 1]", + " M pkg/changed.go", + "A pkg/new.go", + "?? 
pkg/untracked.go", + "D pkg/deleted.go", + "R pkg/renamed.go", + "pkg/old.go", + "C pkg/copied.go", + "pkg/source.go", + "UU pkg/conflicted.go", + )}, nil + case "diff --unified=3 HEAD -- pkg/changed.go": + return GitCommandOutput{Text: "@@ -1,1 +1,1 @@\n-func Old() {}\n+func Changed() {}\n"}, nil + case "diff --unified=3 HEAD -- pkg/new.go": + return GitCommandOutput{Text: "@@ -0,0 +1,3 @@\n+package pkg\n+\n+func Added() {}\n"}, nil + case "diff --unified=3 HEAD -- pkg/renamed.go": + return GitCommandOutput{}, nil + default: + return GitCommandOutput{}, nil + } + }) + + result, err := service.Inspect(context.Background(), workdir, InspectOptions{ + ChangedFilesLimit: 10, + IncludeChangedFileSnippets: true, + ChangedFileSnippetFileCountLimit: 10, + }) + if err != nil { + t.Fatalf("Inspect() error = %v", err) + } + if calls != 3 { + t.Fatalf("expected one status + two diff snippet calls, got %d", calls) + } + if result.Summary.Branch != "main" || result.ChangedFiles.TotalCount != 7 { + t.Fatalf("unexpected inspect result: %+v", result) + } + assertChangedRepositoryFile(t, result.ChangedFiles.Files[0], filepath.Clean("pkg/changed.go"), "", StatusModified, "Changed") + assertChangedRepositoryFile(t, result.ChangedFiles.Files[1], filepath.Clean("pkg/new.go"), "", StatusAdded, "Added") + assertChangedRepositoryFile(t, result.ChangedFiles.Files[2], filepath.Clean("pkg/untracked.go"), "", StatusUntracked, "Untracked") + assertChangedRepositoryFile(t, result.ChangedFiles.Files[3], filepath.Clean("pkg/deleted.go"), "", StatusDeleted, "") + assertChangedRepositoryFile(t, result.ChangedFiles.Files[4], filepath.Clean("pkg/renamed.go"), filepath.Clean("pkg/old.go"), StatusRenamed, "") + assertChangedRepositoryFile(t, result.ChangedFiles.Files[5], filepath.Clean("pkg/copied.go"), filepath.Clean("pkg/source.go"), StatusCopied, "") + assertChangedRepositoryFile(t, result.ChangedFiles.Files[6], filepath.Clean("pkg/conflicted.go"), "", StatusConflicted, "") + }) + + t.Run("changed 
files truncation and snippet filters", func(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + lines := []string{"package pkg"} + for i := 0; i < maxChangedSnippetLinesPerFile+1; i++ { + lines = append(lines, fmt.Sprintf("line %d", i)) + } + content := strings.Join(lines, "\n") + for i := 0; i < 11; i++ { + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", fmt.Sprintf("file%d.txt", i)), content) + } + mustWriteRepositoryFile(t, filepath.Join(workdir, ".env"), "API_KEY=secret\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "bin.dat"), string([]byte{0x00, 0x01})) + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "large.txt"), strings.Repeat("x", maxRepositorySnippetFileBytes+1)) + + service := newRepositoryTestService(func(ctx context.Context, dir string, args ...string) (GitCommandOutput, error) { + if strings.Join(args, " ") != "status --porcelain=v1 -z --branch --untracked-files=normal" { + return GitCommandOutput{}, nil + } + records := []string{"## main", "?? .env", "?? pkg/bin.dat", "?? pkg/large.txt"} + for i := 0; i < 11; i++ { + records = append(records, "?? 
"+filepath.ToSlash(filepath.Join("pkg", fmt.Sprintf("file%d.txt", i)))) + } + return GitCommandOutput{Text: nulJoin(records...)}, nil + }) + + got, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{IncludeSnippets: true}) + if err != nil { + t.Fatalf("ChangedFiles() error = %v", err) + } + if !got.Truncated { + t.Fatalf("expected truncation after total snippet budget exhaustion") + } + for _, file := range got.Files[:3] { + if file.Snippet != "" { + t.Fatalf("expected filtered file to have empty snippet, got %+v", file) + } + } + if got.Files[len(got.Files)-1].Snippet != "" { + t.Fatalf("expected last snippet to be dropped after budget exhaustion, got %+v", got.Files[len(got.Files)-1]) + } + }) + + t.Run("retrieve path glob text symbol and guards", func(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "target.go"), "package pkg\n\ntype Widget struct{}\n\nfunc BuildWidget() Widget {\n\treturn Widget{}\n}\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "notes.txt"), "Widget appears here too\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "word.txt"), "Food Foo FooBar\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, ".env"), "SECRET=1\n") + + service := NewService() + + pathResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: RetrievalModePath, + Value: "pkg/target.go", + }) + if err != nil || len(pathResult.Hits) != 1 || pathResult.Hits[0].Kind != string(RetrievalModePath) { + t.Fatalf("unexpected path result: (%+v, %v)", pathResult, err) + } + + globResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: RetrievalModeGlob, + Value: "*.go", + }) + if err != nil || len(globResult.Hits) == 0 { + t.Fatalf("Retrieve(glob) = (%+v, %v), want hits", globResult, err) + } + + textResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: 
RetrievalModeText, + Value: "Widget", + }) + if err != nil || len(textResult.Hits) < 2 { + t.Fatalf("Retrieve(text) = (%+v, %v), want hits", textResult, err) + } + + symbolResult, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: RetrievalModeSymbol, + Value: "BuildWidget", + }) + if err != nil || len(symbolResult.Hits) != 1 || symbolResult.Hits[0].LineHint <= 0 { + t.Fatalf("Retrieve(symbol) = (%+v, %v), want one symbol hit", symbolResult, err) + } + + fallbackResult, err := service.retrieveBySymbol(context.Background(), workdir, workdir, RetrievalQuery{ + Mode: RetrievalModeSymbol, + Value: "Foo", + Limit: 5, + ContextLines: 1, + }) + if err != nil { + t.Fatalf("retrieveBySymbol fallback error = %v", err) + } + if len(fallbackResult.Hits) != 1 || !strings.Contains(fallbackResult.Hits[0].Snippet, "Food Foo FooBar") { + t.Fatalf("unexpected whole-word fallback hits: %+v", fallbackResult.Hits) + } + for _, hit := range fallbackResult.Hits { + if hit.Kind != string(RetrievalModeSymbol) { + t.Fatalf("expected symbol fallback kind, got %+v", hit) + } + } + + filteredPath, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: RetrievalModePath, + Value: ".env", + }) + if err != nil || len(filteredPath.Hits) != 0 { + t.Fatalf("expected sensitive file to be filtered, got (%+v, %v)", filteredPath, err) + } + + _, err = service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: RetrievalMode("invalid"), + Value: "x", + }) + if !errors.Is(err, errInvalidMode) { + t.Fatalf("expected invalid mode error, got %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := service.Retrieve(ctx, workdir, RetrievalQuery{Mode: RetrievalModeText, Value: "Widget"}); !errors.Is(err, context.Canceled) { + t.Fatalf("expected canceled retrieve, got %v", err) + } + }) +} + +func TestRepositoryHelpersAndGitParsing(t *testing.T) { + t.Parallel() + + t.Run("git parsing helpers", func(t 
*testing.T) { + t.Parallel() + + if branch, ahead, behind := parseBranchLine(""); branch != "" || ahead != 0 || behind != 0 { + t.Fatalf("parseBranchLine(empty) = (%q,%d,%d)", branch, ahead, behind) + } + if branch, _, _ := parseBranchLine("HEAD (no branch)"); branch != "detached" { + t.Fatalf("parseBranchLine(detached) = %q", branch) + } + if branch, ahead, behind := parseBranchLine("main [ahead nope, behind 3]"); branch != "main" || ahead != 0 || behind != 3 { + t.Fatalf("parseBranchLine(invalid tracking) = (%q,%d,%d)", branch, ahead, behind) + } + if ahead, behind := parseTrackingCounters("main [ahead 2, weird, behind 1, ahead nope]"); ahead != 2 || behind != 1 { + t.Fatalf("parseTrackingCounters() = (%d,%d)", ahead, behind) + } + if got := splitNulRecords("a\x00b\x00\x00"); !slices.Equal(got, []string{"a", "b"}) { + t.Fatalf("splitNulRecords() = %#v", got) + } + + tests := []struct { + records []string + ok bool + consumed int + }{ + {records: nil, consumed: 1}, + {records: []string{"?? 
pkg/new.go"}, ok: true, consumed: 1}, + {records: []string{"R new.go", "old.go"}, ok: true, consumed: 2}, + {records: []string{"C copied.go", "source.go"}, ok: true, consumed: 2}, + {records: []string{"XY file.txt"}, consumed: 1}, + } + for _, tt := range tests { + _, consumed, ok := parseChangedRecord(tt.records) + if ok != tt.ok || consumed != tt.consumed { + t.Fatalf("parseChangedRecord(%v) = (ok=%v, consumed=%d)", tt.records, ok, consumed) + } + } + if normalizeStatus('U', 'A') != StatusConflicted || + normalizeStatus('R', ' ') != StatusRenamed || + normalizeStatus('C', ' ') != StatusCopied || + normalizeStatus('D', ' ') != StatusDeleted || + normalizeStatus('A', ' ') != StatusAdded || + normalizeStatus('M', ' ') != StatusModified || + normalizeStatus('X', 'Y') != "" { + t.Fatalf("normalizeStatus() mapping mismatch") + } + }) + + t.Run("path, snippet, workspace, and context helpers", func(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "a.go"), "package pkg\n\nconst Name = \"Widget\"\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "b.txt"), "Widget appears twice\nWidget\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "node_modules", "ignored.txt"), "ignored") + + if _, _, _, err := normalizeRetrievalQuery(workdir, RetrievalQuery{Mode: RetrievalModePath, Value: " "}); err == nil { + t.Fatal("expected empty query error") + } + if _, _, _, err := normalizeRetrievalQuery(workdir, RetrievalQuery{Mode: RetrievalMode("x"), Value: "a"}); !errors.Is(err, errInvalidMode) { + t.Fatalf("expected invalid mode error, got %v", err) + } + if _, _, _, err := normalizeRetrievalQuery(workdir, RetrievalQuery{Mode: RetrievalModePath, Value: "a", ScopeDir: "pkg/a.go"}); err == nil { + t.Fatal("expected scope is not directory error") + } + + root, scope, normalized, err := normalizeRetrievalQuery(workdir, RetrievalQuery{ + Mode: RetrievalModeText, + Value: " Widget ", + Limit: 999, + 
ContextLines: -1, + }) + if err != nil || root == "" || scope == "" { + t.Fatalf("normalizeRetrievalQuery() = (%q, %q, %+v, %v)", root, scope, normalized, err) + } + if normalized.Value != "Widget" || normalized.Limit != maxRetrievalLimit || normalized.ContextLines != defaultContextLines { + t.Fatalf("unexpected normalized query: %+v", normalized) + } + + lines := splitNonEmptyLines("a\r\n\n b \n\t\nc") + if !slices.Equal(lines, []string{"a", " b ", "c"}) { + t.Fatalf("splitNonEmptyLines() = %#v", lines) + } + if snippet := trimSnippetText("a\nb\nc", 2); !snippet.truncated || snippet.lines != 2 { + t.Fatalf("trimSnippetText() = %+v", snippet) + } + if text, changed := truncateSnippetLine(strings.Repeat("x", maxSnippetLineRunes+1), maxSnippetLineRunes); !changed || len([]rune(text)) != maxSnippetLineRunes { + t.Fatalf("truncateSnippetLine() = (%q, %v)", text, changed) + } + if snippet, hint := snippetAroundLine("a\nb\nc", 99, 1); hint != 3 || !strings.Contains(snippet, "c") { + t.Fatalf("snippetAroundLine() = (%q,%d)", snippet, hint) + } + + var visited []string + err = walkWorkspaceFiles(context.Background(), workdir, workdir, func(path string) error { + visited = append(visited, filepath.Base(path)) + return nil + }) + if err != nil { + t.Fatalf("walkWorkspaceFiles() error = %v", err) + } + if slices.Contains(visited, "ignored.txt") { + t.Fatalf("expected node_modules to be skipped, got %v", visited) + } + stopErr := errors.New("stop") + if err := walkWorkspaceFiles(context.Background(), workdir, workdir, func(path string) error { return stopErr }); !errors.Is(err, stopErr) { + t.Fatalf("expected callback error to bubble up, got %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := walkWorkspaceFiles(ctx, workdir, workdir, func(path string) error { return nil }); !errors.Is(err, context.Canceled) { + t.Fatalf("expected canceled walk, got %v", err) + } + + if normalizeLimit(0, 3, 10) != 3 || normalizeLimit(11, 3, 10) != 10 
|| normalizeLimit(4, 3, 10) != 4 { + t.Fatalf("normalizeLimit() mismatch") + } + if filepathSlashClean("a/b") != filepath.Clean(filepath.FromSlash("a/b")) { + t.Fatalf("filepathSlashClean() mismatch") + } + if minInt(1, 2) != 1 || minInt(3, 2) != 2 { + t.Fatalf("minInt() mismatch") + } + }) + + t.Run("run git command and utility helpers", func(t *testing.T) { + t.Parallel() + + if out, err := runGitCommand(context.Background(), t.TempDir(), GitCommandOptions{}, "--version"); err != nil || !strings.Contains(strings.ToLower(out.Text), "git version") { + t.Fatalf("runGitCommand(--version) = (%+v, %v)", out, err) + } + if _, err := runGitCommand(context.Background(), t.TempDir(), GitCommandOptions{}, "unknown-subcommand-for-test"); err == nil { + t.Fatal("expected invalid git subcommand to fail") + } + + if !isNotGitRepository("fatal: not a git repository", errors.New("x")) { + t.Fatal("expected not git repository detection") + } + if isNotGitRepository("", nil) { + t.Fatal("expected nil error to return false") + } + if !isContextError(context.Canceled) || !isContextError(context.DeadlineExceeded) || isContextError(errors.New("x")) { + t.Fatal("isContextError() mismatch") + } + + var buf gitOutputBuffer + if n, err := buf.Write([]byte("abc")); err != nil || n != 3 || buf.String() != "abc" { + t.Fatalf("gitOutputBuffer write = (%d, %v, %q)", n, err, buf.String()) + } + buf = gitOutputBuffer{maxBytes: 2} + if n, err := buf.Write([]byte("abcd")); err != nil || n != 4 || !buf.truncated || buf.String() != "ab" { + t.Fatalf("gitOutputBuffer limited write = (%d, %v, %q, truncated=%v)", n, err, buf.String(), buf.truncated) + } + buf = gitOutputBuffer{maxBytes: 2} + if _, err := buf.Write([]byte("ab")); err != nil { + t.Fatalf("gitOutputBuffer fill error = %v", err) + } + if n, err := buf.Write([]byte("c")); err != nil || n != 1 || !buf.truncated { + t.Fatalf("gitOutputBuffer overflow write = (%d, %v, truncated=%v)", n, err, buf.truncated) + } + root := t.TempDir() + 
mustWriteRepositoryFile(t, filepath.Join(root, ".git", "HEAD"), "ref: refs/heads/main\n") + if !hasGitMetadataAncestor(filepath.Join(root, ".git", "objects")) { + t.Fatal("expected .git ancestor detection") + } + if !isAmbiguousGitStatusOutsideRepo(t.TempDir(), "", errors.New("exit status 128")) { + t.Fatal("expected ambiguous outside-repo status to be treated as non-git") + } + if isAmbiguousGitStatusOutsideRepo(root, "", errors.New("exit status 128")) { + t.Fatal("expected git ancestor to disable ambiguous outside-repo fallback") + } + }) +} + +func TestRepositoryReadSearchAndServiceEntrypoints(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "readable.go"), "package pkg\n\nfunc Readable() {}\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "search.txt"), "alpha beta\nalpha\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "search2.txt"), "alpha again\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "defs.go"), "package pkg\n\nfunc BuildWidget(\n\tname string,\n) string {\n\treturn name\n}\n\ntype Widget struct{}\nconst WidgetName = \"x\"\nvar WidgetVar = 1\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "notes.py"), "def py_symbol():\n return 1\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, ".env"), "SECRET=1\n") + mustWriteRepositoryFile(t, filepath.Join(workdir, "pkg", "bin.dat"), string([]byte{0x00, 0x01, 0x02})) + + service := NewService() + + readResult, err := service.Read(context.Background(), workdir, "pkg/readable.go", ReadOptions{MaxBytes: 12}) + if err != nil { + t.Fatalf("Read() error = %v", err) + } + if readResult.Path != filepath.Clean("pkg/readable.go") || !readResult.Truncated || readResult.IsBinary { + t.Fatalf("unexpected read result: %+v", readResult) + } + + binaryRead, err := service.Read(context.Background(), workdir, "pkg/bin.dat", ReadOptions{}) + if err != nil { + t.Fatalf("Read(binary) error = %v", err) + } + if 
!binaryRead.IsBinary { + t.Fatalf("expected binary read result, got %+v", binaryRead) + } + + filteredRead, err := service.Read(context.Background(), workdir, ".env", ReadOptions{}) + if err != nil || filteredRead.Content != "" { + t.Fatalf("expected sensitive read to be filtered, got (%+v, %v)", filteredRead, err) + } + + textResult, err := service.SearchText(context.Background(), workdir, "alpha", SearchOptions{Limit: 1}) + if err != nil { + t.Fatalf("SearchText() error = %v", err) + } + if len(textResult.Hits) != 1 || !textResult.Truncated || textResult.TotalCount == 0 { + t.Fatalf("unexpected text search result: %+v", textResult) + } + + symbolResult, err := service.SearchSymbol(context.Background(), workdir, "BuildWidget", SearchOptions{Limit: 10}) + if err != nil { + t.Fatalf("SearchSymbol(go) error = %v", err) + } + if len(symbolResult.Hits) == 0 || symbolResult.Hits[0].Kind != "function" || !strings.Contains(symbolResult.Hits[0].Signature, "func BuildWidget") { + t.Fatalf("unexpected go symbol result: %+v", symbolResult) + } + + treeResult, err := service.SearchSymbol(context.Background(), workdir, "py_symbol", SearchOptions{Limit: 10}) + if err != nil { + t.Fatalf("SearchSymbol(tree-sitter) error = %v", err) + } + if len(treeResult.Hits) == 0 || treeResult.Hits[0].Kind == "" { + t.Fatalf("unexpected tree-sitter symbol result: %+v", treeResult) + } + + fallbackResult, err := service.SearchSymbol(context.Background(), workdir, "alpha", SearchOptions{Limit: 10}) + if err != nil { + t.Fatalf("SearchSymbol(fallback) error = %v", err) + } + if len(fallbackResult.Hits) == 0 || fallbackResult.Hits[0].Kind != "reference" { + t.Fatalf("unexpected fallback symbol result: %+v", fallbackResult) + } + + if got := extractGoSignature("func BuildWidget(\n\tname string,\n) string {\n\treturn name\n}\n", 1); !strings.Contains(got, "name string") { + t.Fatalf("extractGoSignature(multiline) = %q", got) + } + if got := classifyGoSignature("func (*Svc).BuildWidget() {}"); got != 
"method" { + t.Fatalf("classifyGoSignature(method) = %q", got) + } + if got := classifyGoSignature("const Widget = 1"); got != "constant" { + t.Fatalf("classifyGoSignature(const) = %q", got) + } + if got := classifyGoSignature("var Widget = 1"); got != "variable" { + t.Fatalf("classifyGoSignature(var) = %q", got) + } + if got := classifyGoSignature("type Widget struct{}"); got != "type" { + t.Fatalf("classifyGoSignature(type) = %q", got) + } + if got := classifyGoSignature("???"); got != "unknown" { + t.Fatalf("classifyGoSignature(unknown) = %q", got) + } + + root, scope, err := resolveSearchScope(workdir, "pkg") + if err != nil || root == "" || scope == "" { + t.Fatalf("resolveSearchScope() = (%q, %q, %v)", root, scope, err) + } + if _, _, err := resolveSearchScope(workdir, "../bad"); err == nil { + t.Fatal("expected invalid search scope error") + } +} + +func newRepositoryTestService(runner func(ctx context.Context, workdir string, args ...string) (GitCommandOutput, error)) *Service { + return &Service{ + gitRunner: func(ctx context.Context, workdir string, opts GitCommandOptions, args ...string) (GitCommandOutput, error) { + return runner(ctx, workdir, args...) 
// nulJoin renders records as a NUL-terminated list: every record is followed
// by a single "\x00" byte. Calling it with no records yields the empty string.
func nulJoin(records ...string) string {
	var out strings.Builder
	for _, record := range records {
		out.WriteString(record)
		out.WriteString("\x00")
	}
	return out.String()
}
limit reached") +// Read 按路径读取目标文件的受限内容(codebase_read)。 +func (s *Service) Read(ctx context.Context, workdir string, path string, opts ReadOptions) (ReadResult, error) { + if err := ctx.Err(); err != nil { + return ReadResult{}, err + } + target, info, allowed, err := resolveRepositorySnippetFileFromRoot(workdir, path) + if err != nil { + return ReadResult{}, err + } + if !allowed { + return ReadResult{}, nil + } + content, err := s.readFile(target) + if err != nil { + if os.IsNotExist(err) { + return ReadResult{}, nil + } + return ReadResult{}, err + } + isBinary := isBinaryContent(content) + maxBytes := opts.MaxBytes + if maxBytes <= 0 { + maxBytes = defaultReadMaxBytes + } + truncated := false + if len(content) > maxBytes { + content = content[:maxBytes] + truncated = true + } + rel, _ := filepath.Rel(workdir, target) + return ReadResult{ + Path: filepath.Clean(rel), + Content: string(content), + Truncated: truncated, + IsBinary: isBinary, + Size: info.Size(), + }, nil +} + +// SearchText 扫描工作区文本文件并返回稳定排序的关键字命中(硬约束:不返回代码内容)。 +func (s *Service) SearchText(ctx context.Context, workdir string, query string, opts SearchOptions) (TextSearchResult, error) { + if err := ctx.Err(); err != nil { + return TextSearchResult{}, err + } + root, scope, err := resolveSearchScope(workdir, opts.ScopeDir) + if err != nil { + return TextSearchResult{}, err + } + + effectiveLimit := opts.Limit + 1 + if effectiveLimit <= 1 { + effectiveLimit = defaultRetrievalLimit + 1 + } + + var wholeWordRe *regexp.Regexp + if opts.WholeWord { + wholeWordRe = regexp.MustCompile(`\b` + regexp.QuoteMeta(query) + `\b`) + } + + hits := make([]TextSearchHit, 0, effectiveLimit) + truncated := false + + err = walkWorkspaceFiles(ctx, root, scope, func(path string) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if len(hits) >= effectiveLimit { + return errRetrievalLimitReached + } + content, ok := s.readRetrievalText(root, path) + if !ok { + return nil + } + lines := 
strings.Split(strings.ReplaceAll(content, "\r\n", "\n"), "\n") + matchCount := 0 + firstLine := 0 + for index, line := range lines { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + matched := strings.Contains(line, query) + if wholeWordRe != nil { + matched = wholeWordRe.MatchString(line) + } + if matched { + matchCount++ + if firstLine == 0 { + firstLine = index + 1 + } + } + } + if matchCount == 0 { + return nil + } + rel, _ := filepath.Rel(root, path) + hits = append(hits, TextSearchHit{ + Path: filepath.Clean(rel), + LineHint: firstLine, + MatchCount: matchCount, + }) + if len(hits) >= effectiveLimit { + return errRetrievalLimitReached + } + return nil + }) + if err != nil { + if errors.Is(err, errRetrievalLimitReached) { + err = nil + } + } + if err != nil { + return TextSearchResult{}, err + } + totalCount := len(hits) + if len(hits) > opts.Limit && opts.Limit > 0 { + hits = hits[:opts.Limit] + truncated = true + } + + sort.Slice(hits, func(i, j int) bool { + if hits[i].Path == hits[j].Path { + return hits[i].LineHint < hits[j].LineHint + } + return hits[i].Path < hits[j].Path + }) + return TextSearchResult{Hits: hits, Truncated: truncated, TotalCount: totalCount}, nil +} + +// SearchSymbol 先做 Go 定义检索,再在无定义命中时回退到 whole-word 文本检索(硬约束:仅返回签名)。 +func (s *Service) SearchSymbol(ctx context.Context, workdir string, symbol string, opts SearchOptions) (SymbolSearchResult, error) { + if err := ctx.Err(); err != nil { + return SymbolSearchResult{}, err + } + root, scope, err := resolveSearchScope(workdir, opts.ScopeDir) + if err != nil { + return SymbolSearchResult{}, err + } + + effectiveLimit := opts.Limit + 1 + if effectiveLimit <= 1 { + effectiveLimit = defaultRetrievalLimit + 1 + } + + hits := make([]SymbolSearchHit, 0, effectiveLimit) + truncated := false + + err = walkWorkspaceFiles(ctx, root, scope, func(path string) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if len(hits) >= effectiveLimit { + return 
errRetrievalLimitReached + } + if filepath.Ext(path) != ".go" { + return nil + } + content, ok := s.readRetrievalText(root, path) + if !ok { + return nil + } + lineNumbers := findGoSymbolDefinitions(content, symbol) + for _, lineNumber := range lineNumbers { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if len(hits) >= effectiveLimit { + break + } + sig := extractGoSignature(content, lineNumber) + kind := classifyGoSignature(sig) + rel, _ := filepath.Rel(root, path) + hits = append(hits, SymbolSearchHit{ + Path: filepath.Clean(rel), + LineHint: lineNumber, + Kind: kind, + Signature: sig, + }) + if len(hits) >= effectiveLimit { + return errRetrievalLimitReached + } + } + return nil + }) + if err != nil { + if errors.Is(err, errRetrievalLimitReached) { + err = nil + } + } + if err != nil { + return SymbolSearchResult{}, err + } + totalCount := len(hits) + if len(hits) > opts.Limit && opts.Limit > 0 { + hits = hits[:opts.Limit] + truncated = true + } + if len(hits) > 0 { + sort.Slice(hits, func(i, j int) bool { + if hits[i].Path == hits[j].Path { + return hits[i].LineHint < hits[j].LineHint + } + return hits[i].Path < hits[j].Path + }) + return SymbolSearchResult{Hits: hits, Truncated: truncated, TotalCount: totalCount}, nil + } + + // Second pass: Tree-sitter for non-Go files. 
+ if err := s.treesitterIndex.EnsureBuilt(ctx, root, scope, s.readFile); err != nil { + return SymbolSearchResult{}, err + } + tsHits := s.treesitterIndex.Search(symbol, effectiveLimit) + for _, h := range tsHits { + if ctx.Err() != nil { + return SymbolSearchResult{}, ctx.Err() + } + if len(hits) >= effectiveLimit { + break + } + hits = append(hits, h) + } + totalCount = len(hits) + if len(hits) > opts.Limit && opts.Limit > 0 { + hits = hits[:opts.Limit] + truncated = true + } + if len(hits) > 0 { + sort.Slice(hits, func(i, j int) bool { + if hits[i].Path == hits[j].Path { + return hits[i].LineHint < hits[j].LineHint + } + return hits[i].Path < hits[j].Path + }) + return SymbolSearchResult{Hits: hits, Truncated: truncated, TotalCount: totalCount}, nil + } + + // Fallback: whole-word text search (all file types). + fallbackOpts := opts + fallbackOpts.WholeWord = true + textResult, err := s.SearchText(ctx, workdir, symbol, fallbackOpts) + if err != nil { + return SymbolSearchResult{}, err + } + for _, th := range textResult.Hits { + hits = append(hits, SymbolSearchHit{ + Path: th.Path, + LineHint: th.LineHint, + Kind: "reference", + }) + } + return SymbolSearchResult{Hits: hits, Truncated: textResult.Truncated, TotalCount: len(hits)}, nil +} + +// extractGoSignature 从指定行提取声明签名,最长 maxSignatureLength 字符。 +func extractGoSignature(content string, lineNumber int) string { + lines := strings.Split(strings.ReplaceAll(content, "\r\n", "\n"), "\n") + if lineNumber <= 0 || lineNumber > len(lines) { + return "" + } + idx := lineNumber - 1 + sig := strings.TrimSpace(lines[idx]) + // 尝试拼接多行函数签名(以 ( 开头但未以 ) 结尾时)。 + if strings.HasPrefix(sig, "func") && strings.Contains(sig, "(") && !strings.Contains(sig, ")") { + for i := idx + 1; i < len(lines) && len(sig) < maxSignatureLength*2; i++ { + sig += " " + strings.TrimSpace(lines[i]) + if strings.Contains(lines[i], ")") { + break + } + } + } + if len(sig) > maxSignatureLength { + sig = sig[:maxSignatureLength] + } + return sig +} + +// 
// classifyGoSignature infers the symbol kind from the leading keyword of a Go
// declaration signature.
//
// It returns "method" for a func whose name is preceded by a parenthesised
// receiver, "function" for any other func, "type", "constant" or "variable"
// for the matching keyword, and "unknown" otherwise. Surrounding whitespace is
// ignored.
func classifyGoSignature(sig string) string {
	sig = strings.TrimSpace(sig)
	switch {
	case strings.HasPrefix(sig, "func "):
		// A receiver is the only thing that can put "(" directly after the
		// func keyword in a declaration. Looking for ")" and "." anywhere in
		// the signature misfires on parameter types such as context.Context
		// and misses receivers that contain no dot.
		afterKeyword := strings.TrimLeft(strings.TrimPrefix(sig, "func "), " \t")
		if strings.HasPrefix(afterKeyword, "(") {
			return "method"
		}
		return "function"
	case strings.HasPrefix(sig, "type "):
		return "type"
	case strings.HasPrefix(sig, "const "):
		return "constant"
	case strings.HasPrefix(sig, "var "):
		return "variable"
	default:
		return "unknown"
	}
}
+// searchSymbolsWithTreeSitter 使用 Tree-sitter 在非 Go 文件中搜索符号定义。 +// 它遍历工作区,对支持的文件类型执行 tags.scm query,匹配 name capture。 +func searchSymbolsWithTreeSitter( + ctx context.Context, + root string, + scope string, + symbol string, + readFile FileReader, + effectiveLimit int, +) ([]SymbolSearchHit, error) { + hits := make([]SymbolSearchHit, 0, effectiveLimit) + + err := walkWorkspaceFiles(ctx, root, scope, func(path string) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if len(hits) >= effectiveLimit { + return errRetrievalLimitReached + } + // Skip Go files — handled by Go AST path. + if filepath.Ext(path) == ".go" { + return nil + } + + entry := grammars.DetectLanguage(filepath.Base(path)) + if entry == nil { + return nil + } + + content, ok := readRetrievalTextWithReader(root, path, readFile) + if !ok { + return nil + } + + // Fast path: skip files that do not contain the symbol text at all. + if !strings.Contains(content, symbol) { + return nil + } + + lang := entry.Language() + tagsQuery := grammars.ResolveTagsQuery(*entry) + if tagsQuery == "" { + return nil + } + + parser := gotreesitter.NewParser(lang) + tree, parseErr := parser.Parse([]byte(content)) + if parseErr != nil { + return nil + } + rootNode := tree.RootNode() + + query, queryErr := gotreesitter.NewQuery(tagsQuery, lang) + if queryErr != nil { + return nil + } + + cursor := query.Exec(rootNode, lang, []byte(content)) + src := []byte(content) + + for { + match, ok := cursor.NextMatch() + if !ok { + break + } + if len(hits) >= effectiveLimit { + return errRetrievalLimitReached + } + + var defNode *gotreesitter.Node + var defKind string + var nameText string + var nameLine int + + for _, cap := range match.Captures { + capName := cap.Name + if capName == "name" { + nameText = strings.TrimSpace(cap.Text(src)) + nameLine = int(cap.Node.StartPoint().Row) + 1 + } + if strings.HasPrefix(capName, "definition.") { + defNode = cap.Node + defKind = captureNameToKind(capName) + } + } + + if 
nameText != symbol { + continue + } + + sig := "" + if defNode != nil { + sig = extractTreeSitterSignature(src, defNode) + } + if sig == "" && nameLine > 0 { + sig = extractLineSignature(content, nameLine) + } + + rel, _ := filepath.Rel(root, path) + hits = append(hits, SymbolSearchHit{ + Path: filepath.Clean(rel), + LineHint: nameLine, + Kind: defKind, + Signature: sig, + }) + } + return nil + }) + + if err != nil { + if errors.Is(err, errRetrievalLimitReached) { + err = nil + } + } + if err != nil { + return nil, err + } + return hits, nil +} + +// captureNameToKind 将 tags.scm capture 名称映射到统一的符号类别。 +func captureNameToKind(capName string) string { + switch capName { + case "definition.function": + return "function" + case "definition.method": + return "method" + case "definition.class": + return "class" + case "definition.type": + return "type" + case "definition.variable": + return "variable" + case "definition.interface": + return "interface" + case "definition.constant": + return "constant" + } + return "unknown" +} + +// extractTreeSitterSignature 从定义节点提取签名,限制在 maxSignatureLength 内。 +func extractTreeSitterSignature(src []byte, node *gotreesitter.Node) string { + text := node.Text(src) + lines := strings.Split(strings.ReplaceAll(text, "\r\n", "\n"), "\n") + if len(lines) == 0 { + return "" + } + // Take the first line and trim trailing brace body start. 
+ sig := strings.TrimSpace(lines[0]) + if idx := strings.Index(sig, "{"); idx >= 0 { + sig = strings.TrimSpace(sig[:idx]) + } + if len(sig) > maxSignatureLength { + sig = sig[:maxSignatureLength] + } + return sig +} + +// extractLineSignature 从文件内容的指定行提取原始文本作为签名。 +func extractLineSignature(content string, lineNumber int) string { + lines := strings.Split(strings.ReplaceAll(content, "\r\n", "\n"), "\n") + if lineNumber <= 0 || lineNumber > len(lines) { + return "" + } + sig := strings.TrimSpace(lines[lineNumber-1]) + if len(sig) > maxSignatureLength { + sig = sig[:maxSignatureLength] + } + return sig +} + +// readRetrievalTextWithReader 使用给定的 reader 读取检索候选文件。 +func readRetrievalTextWithReader(root string, path string, readFile FileReader) (string, bool) { + target, _, allowed, err := resolveRepositorySnippetFileFromRoot(root, path) + if err != nil || !allowed { + return "", false + } + content, err := readFile(target) + if err != nil || isBinaryContent(content) { + return "", false + } + return string(content), true +} diff --git a/internal/repository/treesitter_helpers_test.go b/internal/repository/treesitter_helpers_test.go new file mode 100644 index 00000000..96f6a605 --- /dev/null +++ b/internal/repository/treesitter_helpers_test.go @@ -0,0 +1,141 @@ +package repository + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/odvcencio/gotreesitter" + "github.com/odvcencio/gotreesitter/grammars" +) + +func TestTreeSitterHelpersAndIndexerBranches(t *testing.T) { + t.Parallel() + + t.Run("helper functions", func(t *testing.T) { + t.Parallel() + + if captureNameToKind("definition.function") != "function" || + captureNameToKind("definition.method") != "method" || + captureNameToKind("definition.class") != "class" || + captureNameToKind("definition.type") != "type" || + captureNameToKind("definition.variable") != "variable" || + captureNameToKind("definition.interface") != "interface" || + captureNameToKind("definition.constant") != 
"constant" || + captureNameToKind("other") != "unknown" { + t.Fatal("captureNameToKind() mismatch") + } + + entry := grammars.DetectLanguageByName("python") + if entry == nil { + t.Fatal("expected python grammar entry") + } + lang := entry.Language() + parser := gotreesitter.NewParser(lang) + tree, err := parser.Parse([]byte("def hello(name):\n return name\n")) + if err != nil { + t.Fatalf("parse helper source: %v", err) + } + if got := extractTreeSitterSignature([]byte("def hello(name):\n return name\n"), tree.RootNode().Children()[0]); got != "def hello(name):" { + t.Fatalf("extractTreeSitterSignature() = %q", got) + } + oversized := "def " + strings.Repeat("a", maxSignatureLength+20) + "():\n pass\n" + tree, err = parser.Parse([]byte(oversized)) + if err != nil { + t.Fatalf("parse oversized helper source: %v", err) + } + if got := extractTreeSitterSignature([]byte(oversized), tree.RootNode().Children()[0]); len(got) != maxSignatureLength { + t.Fatalf("expected signature truncation, got len=%d", len(got)) + } + if got := extractLineSignature("a\nb\nc", 2); got != "b" { + t.Fatalf("extractLineSignature() = %q", got) + } + if got := extractLineSignature("a", 9); got != "" { + t.Fatalf("expected empty out-of-range signature, got %q", got) + } + }) + + t.Run("read helper, search helper, and index lifecycle", func(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + mustWriteRepositoryFile(t, filepath.Join(workspace, "main.py"), "def hello(name):\n return name\n") + mustWriteRepositoryFile(t, filepath.Join(workspace, "main.go"), "package main\nfunc SkipMe() {}\n") + mustWriteRepositoryFile(t, filepath.Join(workspace, "plain.txt"), "hello only text\n") + + content, ok := readRetrievalTextWithReader(workspace, filepath.Join(workspace, "main.py"), readFile) + if !ok || !strings.Contains(content, "hello") { + t.Fatalf("readRetrievalTextWithReader() = (%q, %v)", content, ok) + } + if _, ok := readRetrievalTextWithReader(workspace, filepath.Join(workspace, 
"missing.py"), readFile); ok { + t.Fatal("expected missing file to be skipped") + } + + hits, err := searchSymbolsWithTreeSitter(context.Background(), workspace, workspace, "hello", readFile, 10) + if err != nil || len(hits) == 0 { + t.Fatalf("searchSymbolsWithTreeSitter() = (%+v, %v)", hits, err) + } + if hits[0].Kind == "" || !strings.Contains(hits[0].Signature, "def hello") { + t.Fatalf("unexpected tree-sitter hit: %+v", hits[0]) + } + + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt() error = %v", err) + } + if !idx.isBuilt() || idx.getRoot() != workspace { + t.Fatalf("unexpected index state after build") + } + if got := idx.Search("hello", 1); len(got) != 1 { + t.Fatalf("expected limited search result, got %+v", got) + } + idx.Close() + if idx.isBuilt() || len(idx.Search("hello", 10)) != 0 { + t.Fatalf("expected close to reset index state") + } + }) + + t.Run("refresh replace and delete branches", func(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + pyPath := filepath.Join(workspace, "mod.py") + mustWriteRepositoryFile(t, pyPath, "def before():\n pass\n") + + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt() error = %v", err) + } + if len(idx.Search("before", 10)) == 0 { + t.Fatal("expected initial symbol in index") + } + + if err := os.WriteFile(pyPath, []byte("def after():\n pass\n"), 0o644); err != nil { + t.Fatalf("rewrite file: %v", err) + } + if err := idx.Refresh(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("Refresh() error = %v", err) + } + if len(idx.Search("before", 10)) != 0 || len(idx.Search("after", 10)) == 0 { + t.Fatalf("expected refresh to replace indexed entries") + } + + goPath := filepath.Join(workspace, "skip.go") + mustWriteRepositoryFile(t, goPath, "package main\nfunc Ignore(){}\n") + if 
err := idx.replaceFile(context.Background(), workspace, goPath, readFile); err != nil { + t.Fatalf("replaceFile(go) error = %v", err) + } + + if err := os.Remove(pyPath); err != nil { + t.Fatalf("remove file: %v", err) + } + if err := idx.Refresh(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("Refresh(delete) error = %v", err) + } + if len(idx.Search("after", 10)) != 0 { + t.Fatalf("expected deleted file entries to be removed") + } + }) +} diff --git a/internal/repository/treesitter_index.go b/internal/repository/treesitter_index.go new file mode 100644 index 00000000..1cc1d118 --- /dev/null +++ b/internal/repository/treesitter_index.go @@ -0,0 +1,378 @@ +package repository + +import ( + "context" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "github.com/odvcencio/gotreesitter" + "github.com/odvcencio/gotreesitter/grammars" +) + +// tsIndexEntry 是倒排索引中的单条符号位置。 +type tsIndexEntry struct { + Name string // 符号名称(原始大小写,用于输出) + Path string + LineHint int + Kind string + Signature string +} + +// tsFileMeta 记录已索引文件的 mtime 和大小,用于增量更新检测。 +type tsFileMeta struct { + modTime time.Time + size int64 +} + +// TreeSitterIndexer 管理内存中的跨语言符号倒排索引。 +// 线程安全,支持惰性初始化 + 增量更新。 +type TreeSitterIndexer struct { + mu sync.RWMutex + built bool + root string + symbols map[string][]tsIndexEntry // lower(symbol) → entries + fileMeta map[string]tsFileMeta // path → meta +} + +// NewTreeSitterIndexer 返回一个惰性初始化的 Tree-sitter 索引器。 +func NewTreeSitterIndexer() *TreeSitterIndexer { + return &TreeSitterIndexer{ + symbols: make(map[string][]tsIndexEntry), + fileMeta: make(map[string]tsFileMeta), + } +} + +// Search 从索引中查询符号,返回精确匹配和前缀匹配的结果。 +// 返回的 hits 按 path + line_hint 排序。 +func (idx *TreeSitterIndexer) Search(name string, limit int) []SymbolSearchHit { + idx.mu.RLock() + defer idx.mu.RUnlock() + + if !idx.built { + return nil + } + + key := strings.ToLower(strings.TrimSpace(name)) + if key == "" { + return nil + } + + // 精确匹配 + exact := 
idx.symbols[key] + // 前缀匹配(如 "Hello" 匹配 "HelloWorld") + var prefix []tsIndexEntry + for k, entries := range idx.symbols { + if k != key && strings.HasPrefix(k, key) { + prefix = append(prefix, entries...) + } + } + + total := append(exact, prefix...) + if len(total) == 0 { + return nil + } + if limit > 0 && len(total) > limit { + total = total[:limit] + } + + hits := make([]SymbolSearchHit, len(total)) + for i, e := range total { + hits[i] = SymbolSearchHit{ + Path: e.Path, + LineHint: e.LineHint, + Kind: e.Kind, + Signature: e.Signature, + } + } + sort.Slice(hits, func(i, j int) bool { + if hits[i].Path == hits[j].Path { + return hits[i].LineHint < hits[j].LineHint + } + return hits[i].Path < hits[j].Path + }) + return hits +} + +// EnsureBuilt 惰性初始化索引:扫描工作区,解析所有非 Go 文件,构建倒排索引。 +func (idx *TreeSitterIndexer) EnsureBuilt(ctx context.Context, root string, scope string, readFile FileReader) error { + idx.mu.Lock() + defer idx.mu.Unlock() + + if idx.built && idx.root == root { + return nil + } + + idx.root = root + idx.symbols = make(map[string][]tsIndexEntry) + idx.fileMeta = make(map[string]tsFileMeta) + + walkErr := walkWorkspaceFiles(ctx, root, scope, func(absPath string) error { + return idx.indexFile(ctx, root, absPath, readFile) + }) + idx.built = true + if walkErr != nil { + return walkErr + } + return nil +} + +// Refresh 使用文件 mtime 检测变更并增量更新索引。 +func (idx *TreeSitterIndexer) Refresh(ctx context.Context, root string, scope string, readFile FileReader) error { + if !idx.isBuilt() || idx.getRoot() != root { + return idx.EnsureBuilt(ctx, root, scope, readFile) + } + + currentMeta := make(map[string]tsFileMeta) + var changedFiles []string + var deletedFiles []string + var mu sync.Mutex + + walkErr := walkWorkspaceFiles(ctx, root, scope, func(absPath string) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + rel, relErr := filepath.Rel(root, absPath) + if relErr != nil { + return nil + } + cleanRel := filepath.ToSlash(rel) + + info, statErr := 
os.Stat(absPath) + if statErr != nil { + return nil + } + meta := tsFileMeta{modTime: info.ModTime(), size: info.Size()} + + mu.Lock() + currentMeta[cleanRel] = meta + mu.Unlock() + + idx.mu.RLock() + oldMeta, exists := idx.fileMeta[cleanRel] + idx.mu.RUnlock() + + if !exists || !meta.modTime.Equal(oldMeta.modTime) || meta.size != oldMeta.size { + mu.Lock() + changedFiles = append(changedFiles, absPath) + mu.Unlock() + } + return nil + }) + if walkErr != nil { + return walkErr + } + + idx.mu.RLock() + for path := range idx.fileMeta { + if _, exists := currentMeta[path]; !exists { + deletedFiles = append(deletedFiles, path) + } + } + idx.mu.RUnlock() + + // 重新索引变更文件 + for _, absPath := range changedFiles { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + idx.replaceFile(ctx, root, absPath, readFile) + } + + // 删除已移除文件的索引条目 + idx.mu.Lock() + for _, cleanRel := range deletedFiles { + idx.deleteFileEntries(cleanRel) + } + idx.mu.Unlock() + + return nil +} + +// Close 释放索引持有的内存。 +func (idx *TreeSitterIndexer) Close() { + idx.mu.Lock() + defer idx.mu.Unlock() + idx.symbols = make(map[string][]tsIndexEntry) + idx.fileMeta = make(map[string]tsFileMeta) + idx.built = false +} + +func (idx *TreeSitterIndexer) isBuilt() bool { + idx.mu.RLock() + defer idx.mu.RUnlock() + return idx.built +} + +func (idx *TreeSitterIndexer) getRoot() string { + idx.mu.RLock() + defer idx.mu.RUnlock() + return idx.root +} + +// indexFile 解析单个文件的符号并添加到索引。 +func (idx *TreeSitterIndexer) indexFile(ctx context.Context, root string, absPath string, readFile FileReader) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if filepath.Ext(absPath) == ".go" { + return nil + } + + rel, relErr := filepath.Rel(root, absPath) + if relErr != nil { + return nil + } + cleanRel := filepath.ToSlash(rel) + + entries, meta, ok := idx.parseFile(root, absPath, readFile) + if !ok { + return nil + } + + idx.fileMeta[cleanRel] = meta + for _, e := range entries { + key := 
strings.ToLower(strings.TrimSpace(e.Name)) + idx.symbols[key] = append(idx.symbols[key], e) + } + return nil +} + +// replaceFile 替换已有文件的索引条目(先删后加)。 +func (idx *TreeSitterIndexer) replaceFile(ctx context.Context, root string, absPath string, readFile FileReader) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + rel, relErr := filepath.Rel(root, absPath) + if relErr != nil { + return nil + } + cleanRel := filepath.ToSlash(rel) + + if filepath.Ext(absPath) == ".go" { + idx.mu.Lock() + idx.deleteFileEntries(cleanRel) + idx.mu.Unlock() + return nil + } + + entries, meta, ok := idx.parseFile(root, absPath, readFile) + idx.mu.Lock() + idx.deleteFileEntries(cleanRel) + if ok { + idx.fileMeta[cleanRel] = meta + for _, e := range entries { + key := strings.ToLower(strings.TrimSpace(e.Name)) + idx.symbols[key] = append(idx.symbols[key], e) + } + } + idx.mu.Unlock() + return nil +} + +// deleteFileEntries 从索引中删除指定文件的全部条目。 +func (idx *TreeSitterIndexer) deleteFileEntries(cleanRel string) { + delete(idx.fileMeta, cleanRel) + for key, entries := range idx.symbols { + filtered := entries[:0] + for _, e := range entries { + if e.Path != cleanRel { + filtered = append(filtered, e) + } + } + if len(filtered) == 0 { + delete(idx.symbols, key) + } else { + idx.symbols[key] = filtered + } + } +} + +// parseFile 使用 Tree-sitter 解析一个文件,返回所有符号条目 + 文件元信息。 +func (idx *TreeSitterIndexer) parseFile(root string, absPath string, readFile FileReader) ([]tsIndexEntry, tsFileMeta, bool) { + content, ok := readRetrievalTextWithReader(root, absPath, readFile) + if !ok { + return nil, tsFileMeta{}, false + } + + info, statErr := os.Stat(absPath) + if statErr != nil { + return nil, tsFileMeta{}, false + } + meta := tsFileMeta{modTime: info.ModTime(), size: info.Size()} + + entry := grammars.DetectLanguage(filepath.Base(absPath)) + if entry == nil { + return nil, meta, false + } + lang := entry.Language() + tagsQuery := grammars.ResolveTagsQuery(*entry) + if tagsQuery == "" { + 
return nil, meta, false + } + + parser := gotreesitter.NewParser(lang) + tree, parseErr := parser.Parse([]byte(content)) + if parseErr != nil { + return nil, meta, false + } + rootNode := tree.RootNode() + + query, queryErr := gotreesitter.NewQuery(tagsQuery, lang) + if queryErr != nil { + return nil, meta, false + } + + cursor := query.Exec(rootNode, lang, []byte(content)) + src := []byte(content) + var entries []tsIndexEntry + + for { + match, ok := cursor.NextMatch() + if !ok { + break + } + var defNode *gotreesitter.Node + var defKind string + var nameText string + var nameLine int + + for _, cap := range match.Captures { + capName := cap.Name + if capName == "name" { + nameText = strings.TrimSpace(cap.Text(src)) + nameLine = int(cap.Node.StartPoint().Row) + 1 + } + if strings.HasPrefix(capName, "definition.") { + defNode = cap.Node + defKind = captureNameToKind(capName) + } + } + + if nameText == "" { + continue + } + + sig := "" + if defNode != nil { + sig = extractTreeSitterSignature(src, defNode) + } + if sig == "" && nameLine > 0 { + sig = extractLineSignature(content, nameLine) + } + + rel, _ := filepath.Rel(root, absPath) + entries = append(entries, tsIndexEntry{ + Name: nameText, + Path: filepath.Clean(rel), + LineHint: nameLine, + Kind: defKind, + Signature: sig, + }) + } + + return entries, meta, true +} diff --git a/internal/repository/treesitter_test.go b/internal/repository/treesitter_test.go new file mode 100644 index 00000000..4f6ebf35 --- /dev/null +++ b/internal/repository/treesitter_test.go @@ -0,0 +1,279 @@ +package repository + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestTreeSitterIndexerPythonDefinitions(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + src := `def hello(name): + return "hi " + name + +class MyClass: + def method(self): + pass +` + if err := os.WriteFile(filepath.Join(workspace, "main.py"), []byte(src), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + idx 
:= NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + hits := idx.Search("hello", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'hello'") + } + if hits[0].Kind != "function" { + t.Fatalf("expected function kind, got %q", hits[0].Kind) + } + if !strings.Contains(hits[0].Signature, "def hello") { + t.Fatalf("expected signature containing 'def hello', got %q", hits[0].Signature) + } + + hits = idx.Search("MyClass", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'MyClass'") + } + if hits[0].Kind != "class" { + t.Fatalf("expected class kind, got %q", hits[0].Kind) + } + + hits = idx.Search("method", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'method'") + } + // Python Tree-sitter grammar tags query categorizes def inside class as + // definition.function, not definition.method. This is a grammar-level + // limitation — the index correctly reflects what the grammar reports. 
+ if hits[0].Kind != "function" && hits[0].Kind != "method" { + t.Fatalf("expected function/method kind, got %q", hits[0].Kind) + } +} + +func TestTreeSitterIndexerTypeScriptDefinitions(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + src := `function greet(name: string): string { + return "hello " + name; +} + +class Person { + constructor(public name: string) {} + sayHi(): void {} +} +` + if err := os.WriteFile(filepath.Join(workspace, "main.ts"), []byte(src), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + hits := idx.Search("greet", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'greet'") + } + if hits[0].Kind != "function" { + t.Fatalf("expected function kind, got %q", hits[0].Kind) + } + + hits = idx.Search("sayHi", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'sayHi'") + } + if hits[0].Kind != "method" { + t.Fatalf("expected method kind, got %q", hits[0].Kind) + } +} + +func TestTreeSitterIndexerJavaDefinitions(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + src := `public class Main { + public String greet(String name) { + return "hello " + name; + } +} +` + if err := os.WriteFile(filepath.Join(workspace, "Main.java"), []byte(src), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + hits := idx.Search("Main", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'Main'") + } + if hits[0].Kind != "class" { + t.Fatalf("expected class kind, got %q", hits[0].Kind) + } +} + +func TestTreeSitterIndexerRustDefinitions(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + src := `fn hello(name: &str) -> String { + format!("hi {}", name) +} + +struct MyStruct 
{ + field: i32, +} +` + if err := os.WriteFile(filepath.Join(workspace, "main.rs"), []byte(src), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + hits := idx.Search("hello", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'hello'") + } + if hits[0].Kind != "function" { + t.Fatalf("expected function kind, got %q", hits[0].Kind) + } + + hits = idx.Search("MyStruct", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'MyStruct'") + } +} + +func TestTreeSitterIndexerSkipsGoFiles(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + src := `package main +func GoOnlyFunction() {} +` + if err := os.WriteFile(filepath.Join(workspace, "main.go"), []byte(src), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + // Go files should be skipped by Tree-sitter indexer + hits := idx.Search("GoOnlyFunction", 10) + if len(hits) > 0 { + t.Fatalf("expected no hits for Go function in Tree-sitter index, got %d", len(hits)) + } +} + +func TestTreeSitterIndexerNoResults(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + if err := os.WriteFile(filepath.Join(workspace, "main.py"), []byte("x = 1\n"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + hits := idx.Search("NonExistentSymbol", 10) + if len(hits) != 0 { + t.Fatalf("expected 0 hits, got %d", len(hits)) + } +} + +func TestTreeSitterIndexerEmptyWorkspace(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + idx := NewTreeSitterIndexer() + if err := 
idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + hits := idx.Search("anything", 10) + if len(hits) != 0 { + t.Fatalf("expected 0 hits in empty workspace, got %d", len(hits)) + } +} + +func TestTreeSitterIndexerRefreshDetectsNewFile(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + // Initially empty + hits := idx.Search("hello", 10) + if len(hits) != 0 { + t.Fatalf("expected 0 hits initially, got %d", len(hits)) + } + + // Add a file + if err := os.WriteFile(filepath.Join(workspace, "main.py"), []byte("def hello():\n pass\n"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + if err := idx.Refresh(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("Refresh: %v", err) + } + + hits = idx.Search("hello", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'hello' after refresh") + } +} + +func TestTreeSitterIndexerPrefixSearch(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + src := `def helloWorld(): pass +def helloYou(): pass +def goodbye(): pass +` + if err := os.WriteFile(filepath.Join(workspace, "main.py"), []byte(src), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + idx := NewTreeSitterIndexer() + if err := idx.EnsureBuilt(context.Background(), workspace, workspace, readFile); err != nil { + t.Fatalf("EnsureBuilt: %v", err) + } + + // Exact search for "helloWorld" + hits := idx.Search("helloWorld", 10) + if len(hits) == 0 { + t.Fatal("expected to find 'helloWorld'") + } + + // Prefix search "hello" should match both "helloWorld" and "helloYou" + hits = idx.Search("hello", 10) + if len(hits) < 2 { + t.Fatalf("expected at least 2 prefix matches for 'hello', got %d", len(hits)) + } +} diff --git a/internal/context/repository/types.go 
b/internal/repository/types.go similarity index 80% rename from internal/context/repository/types.go rename to internal/repository/types.go index d9938163..ebac02a2 100644 --- a/internal/context/repository/types.go +++ b/internal/repository/types.go @@ -25,6 +25,21 @@ const ( RetrievalModeSymbol RetrievalMode = "symbol" ) +// FileChangeKind 表示两个 checkpoint 之间单个 path 的变更类别。 +type FileChangeKind string + +const ( + FileChangeAdded FileChangeKind = "added" + FileChangeDeleted FileChangeKind = "deleted" + FileChangeModified FileChangeKind = "modified" +) + +// FileChangeEntry 描述端到端 diff 中单个 path 的变更。 +type FileChangeEntry struct { + Path string + Kind FileChangeKind +} + // Summary 描述当前工作区相对仓库的最小事实快照。 type Summary struct { InGitRepo bool @@ -75,7 +90,7 @@ type RetrievalQuery struct { ContextLines int } -// RetrievalHit 表示单个检索命中的结构化结果。 +// RetrievalHit 表示单个检索命中的结构化结果(兼容旧上下文,含代码片段)。 type RetrievalHit struct { Path string Kind string @@ -96,23 +111,69 @@ type InspectResult struct { ChangedFiles ChangedFilesContext } -// Service 提供轻量仓库摘要、变更上下文与定向检索能力。 -type Service struct { - gitRunner gitCommandRunner - readFile fileReader +// ReadOptions 控制 codebase_read 的读取上限。 +type ReadOptions struct { + MaxBytes int } -type snippetResult struct { - text string - lines int - truncated bool +// ReadResult 表示一次受限文件读取的结果。 +type ReadResult struct { + Path string + Content string + Truncated bool + IsBinary bool + Size int64 +} + +// SearchOptions 控制文本/符号搜索的裁剪策略。 +type SearchOptions struct { + ScopeDir string + Limit int + WholeWord bool +} + +// TextSearchHit 表示文本搜索的单文件命中(硬约束:不返回代码内容)。 +type TextSearchHit struct { + Path string + LineHint int + MatchCount int +} + +// TextSearchResult 表示文本搜索的结构化结果。 +type TextSearchResult struct { + Hits []TextSearchHit + Truncated bool + TotalCount int +} + +// SymbolSearchHit 表示符号搜索的单条命中(硬约束:仅返回位置与签名,不含函数体)。 +type SymbolSearchHit struct { + Path string + LineHint int + Kind string + Signature string +} + +// SymbolSearchResult 表示符号搜索的结构化结果。 +type 
SymbolSearchResult struct { + Hits []SymbolSearchHit + Truncated bool + TotalCount int +} + +// Service 提供轻量仓库摘要、变更上下文、定向检索与代码库探索能力。 +type Service struct { + gitRunner GitCommandRunner + readFile FileReader + treesitterIndex *TreeSitterIndexer } // NewService 返回默认的轻量仓库服务实现。 func NewService() *Service { return &Service{ - gitRunner: runGitCommand, - readFile: readFile, + gitRunner: runGitCommand, + readFile: readFile, + treesitterIndex: NewTreeSitterIndexer(), } } @@ -278,3 +339,9 @@ func (s *Service) Retrieve(ctx context.Context, workdir string, query RetrievalQ return RetrievalResult{}, errInvalidMode } } + +type snippetResult struct { + text string + lines int + truncated bool +} diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index 5a6bb735..f1aa1810 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -2,6 +2,7 @@ package runtime import ( "context" + "errors" "os" "path/filepath" "strings" @@ -776,6 +777,76 @@ func TestCheckpointDiffRejectsMissingStateAndReturnsEmptyWhenNoPreviousSnapshot( } } +func TestCreateEndOfTurnCheckpoint_SetsLastCheckpointID(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + fixture.captureFile(t, "tracked.go", []byte("package main\n")) + + state := newRunState("run-eot-id", fixture.session) + fixture.service.createEndOfTurnCheckpoint(context.Background(), &state, true) + + if state.lastEndOfTurnCheckpointID == "" { + t.Fatal("expected lastEndOfTurnCheckpointID to be set after end-of-turn checkpoint creation") + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("records = %#v, want 1", records) + } + wantRef := checkpoint.PerEditCheckpointIDFromRef(records[0].CodeCheckpointRef) + if state.lastEndOfTurnCheckpointID != wantRef { + 
t.Fatalf("lastEndOfTurnCheckpointID = %q, want %q", state.lastEndOfTurnCheckpointID, wantRef) + } +} + +func TestFindPreviousEndOfTurnCheckpoint(t *testing.T) { + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + {CheckpointID: "cp-skip-current", SessionID: "session-1", Reason: agentsession.CheckpointReasonEndOfTurn, CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-skip-current"), RunID: "current-run", Status: agentsession.CheckpointStatusAvailable}, + {CheckpointID: "cp-skip-reason", SessionID: "session-1", Reason: agentsession.CheckpointReasonCompact, CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-skip-reason"), RunID: "old-run", Status: agentsession.CheckpointStatusAvailable}, + {CheckpointID: "cp-valid", SessionID: "session-1", Reason: agentsession.CheckpointReasonEndOfTurn, CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-valid"), RunID: "old-run", Status: agentsession.CheckpointStatusAvailable}, + }, + } + service := &Service{checkpointStore: spy} + + got := service.findPreviousEndOfTurnCheckpoint(context.Background(), "session-1", "current-run") + if got != "cp-valid" { + t.Fatalf("findPreviousEndOfTurnCheckpoint() = %q, want cp-valid", got) + } + if spy.listSessionID != "session-1" || !spy.listOpts.RestorableOnly || spy.listOpts.Limit != 50 { + t.Fatalf("list opts = %#v, want session-1 restorableOnly=true limit=50", spy.listOpts) + } +} + +func TestFindPreviousEndOfTurnCheckpoint_NoStore(t *testing.T) { + service := &Service{} + if got := service.findPreviousEndOfTurnCheckpoint(context.Background(), "session-1", "run-1"); got != "" { + t.Fatalf("expected empty, got %q", got) + } +} + +func TestFindPreviousEndOfTurnCheckpoint_ListError(t *testing.T) { + spy := &checkpointStoreSpy{listErr: errors.New("db down")} + service := &Service{checkpointStore: spy} + if got := service.findPreviousEndOfTurnCheckpoint(context.Background(), "session-1", "run-1"); got != "" { + t.Fatalf("expected empty on list error, 
got %q", got) + } +} + +func TestFindPreviousEndOfTurnCheckpoint_SkipsNonPerEditRef(t *testing.T) { + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + {CheckpointID: "cp-no-ref", SessionID: "session-1", Reason: agentsession.CheckpointReasonEndOfTurn, CodeCheckpointRef: "", RunID: "old-run", Status: agentsession.CheckpointStatusAvailable}, + }, + } + service := &Service{checkpointStore: spy} + if got := service.findPreviousEndOfTurnCheckpoint(context.Background(), "session-1", "current-run"); got != "" { + t.Fatalf("expected empty when no per-edit ref available, got %q", got) + } +} + func mustReadRuntimeFile(t *testing.T, path string) []byte { t.Helper() data, err := os.ReadFile(path) diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go index a74bef2b..10050a40 100644 --- a/internal/runtime/checkpoint_gate.go +++ b/internal/runtime/checkpoint_gate.go @@ -66,7 +66,11 @@ func (s *Service) createEndOfTurnCheckpoint(ctx context.Context, state *runState defer s.perEditStore.Reset() if err := s.createCheckpointRecord(ctx, session, runID, state, checkpointID, agentsession.CheckpointReasonEndOfTurn); err != nil { log.Printf("checkpoint: end-of-turn record: %v", err) + return } + state.mu.Lock() + state.lastEndOfTurnCheckpointID = checkpointID + state.mu.Unlock() } // createCheckpointRecord 写入 SQLite checkpoint 记录 + session 快照,并发出 EventCheckpointCreated。 @@ -191,3 +195,32 @@ func (s *Service) createSessionOnlyCheckpoint( }) return nil } + +// findPreviousEndOfTurnCheckpoint 查询指定 session 中、不属于当前 run 的最新可用 end_of_turn checkpoint。 +// 用于 run-scoped diff 的 baseline 定位;找不到时返回空字符串,不报错。 +func (s *Service) findPreviousEndOfTurnCheckpoint(ctx context.Context, sessionID string, currentRunID string) string { + if s.checkpointStore == nil { + return "" + } + records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{ + Limit: 50, + RestorableOnly: true, + }) + if err != nil { + 
log.Printf("checkpoint: find previous end-of-turn list failed: %v", err) + return "" + } + for _, r := range records { + if r.Reason != agentsession.CheckpointReasonEndOfTurn { + continue + } + if !checkpoint.IsPerEditRef(r.CodeCheckpointRef) { + continue + } + if r.RunID == currentRunID { + continue + } + return checkpoint.PerEditCheckpointIDFromRef(r.CodeCheckpointRef) + } + return "" +} diff --git a/internal/runtime/events.go b/internal/runtime/events.go index ed05c4df..55b62086 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -432,6 +432,8 @@ const ( EventCheckpointUndoRestore EventType = "checkpoint_undo_restore" // EventBashSideEffect 表示 bash 命令在 workdir 内产生了文件变更。 EventBashSideEffect EventType = "bash_side_effect" + // EventRunDiffSummary 表示一次完整 run 的端到端代码变更摘要已生成。 + EventRunDiffSummary EventType = "run_diff_summary" ) // TokenUsagePayload 承载单轮 token 用量统计。 @@ -500,6 +502,14 @@ type ToolDiffPayload struct { Diffs []FileDiffEntry `json:"diffs,omitempty"` } +// RunDiffSummaryPayload 描述一次完整 run 结束时的端到端代码变更摘要。 +type RunDiffSummaryPayload struct { + FromCheckpointID string `json:"from_checkpoint_id,omitempty"` + ToCheckpointID string `json:"to_checkpoint_id,omitempty"` + Diff string `json:"diff,omitempty"` + ChangedFiles []FileDiffEntry `json:"changed_files,omitempty"` +} + // BashSideEffectPayload 描述 bash 命令在 workdir 内的文件变更。 type BashSideEffectPayload struct { ToolCallID string `json:"tool_call_id"` diff --git a/internal/runtime/repository_context.go b/internal/runtime/repository_context.go index be3ebab7..143705c0 100644 --- a/internal/runtime/repository_context.go +++ b/internal/runtime/repository_context.go @@ -3,35 +3,15 @@ package runtime import ( "context" "errors" - "os" - "regexp" "strings" agentcontext "neo-code/internal/context" - "neo-code/internal/context/repository" + "neo-code/internal/repository" providertypes "neo-code/internal/provider/types" - "neo-code/internal/security" ) -const ( - maxAutoChangedFilesCount = 20 - 
maxAutoSnippetChangedFilesCount = 5 - defaultAutoChangedFilesLimit = 10 - defaultAutoChangedFilesWithDiff = 5 - defaultAutoPathRetrievalLimit = 1 - defaultAutoSymbolRetrievalLimit = 3 - defaultAutoTextRetrievalLimit = 5 - defaultAutoRetrievalContextLines = 4 - defaultAutoTextRetrievalContext = 3 -) - -var ( - pathAnchorPattern = regexp.MustCompile(`(?i)(?:[a-z0-9_.-]+[\\/])*[a-z0-9_.-]+\.(go|md|ya?ml|json|toml|txt|sh)\b`) - symbolAnchorPattern = regexp.MustCompile(`\b[A-Z][A-Za-z0-9_]{2,}\b`) - quotedTextPattern = regexp.MustCompile("`([^`]+)`|\"([^\"]+)\"|'([^']+)'") -) - -// buildRepositoryContext 按当前轮输入意图统一编排 repository summary、changed-files 与 retrieval 投影。 +// buildRepositoryContext 返回最小 Git 摘要(迁移期保留),不再自动注入 changed-files 或 retrieval。 +// 模型应通过 git_* / codebase_* 工具主动探索仓库。 func (s *Service) buildRepositoryContext( ctx context.Context, state *runState, @@ -44,45 +24,18 @@ func (s *Service) buildRepositoryContext( return nil, agentcontext.RepositoryContext{}, nil } - latestUserText := latestUserText(state.session.Messages) repoService := s.repositoryFacts() - repoContext := agentcontext.RepositoryContext{} - var summarySection *agentcontext.RepositorySummarySection - - includeChangedFiles := latestUserText != "" && (shouldAutoInjectChangedFiles(latestUserText) || mentionsFixOrReviewIntent(latestUserText)) - includeChangedSnippets := latestUserText != "" && shouldAutoIncludeChangedFileSnippets(latestUserText) - inspectResult, inspectErr := repoService.Inspect(ctx, activeWorkdir, repository.InspectOptions{ - ChangedFilesLimit: changedFilesLimitForUserText(includeChangedSnippets), - IncludeChangedFileSnippets: includeChangedSnippets, - ChangedFileSnippetFileCountLimit: maxAutoSnippetChangedFilesCount, - }) + inspectResult, inspectErr := repoService.Inspect(ctx, activeWorkdir, repository.InspectOptions{}) if inspectErr != nil { if isRepositoryContextFatalError(inspectErr) { return nil, agentcontext.RepositoryContext{}, inspectErr } 
s.emitRepositoryContextUnavailable(ctx, state, "summary", "", inspectErr) - } else { - summarySection = projectRepositorySummary(inspectResult.Summary) - if includeChangedFiles { - if changedFiles := changedFilesProjectionForUserText(latestUserText, inspectResult.ChangedFiles); changedFiles != nil { - repoContext.ChangedFiles = changedFiles - } - } - } - - if query, ok := autoRetrievalQueryFromUserText(activeWorkdir, latestUserText); ok { - retrieval, retrievalErr := s.buildRetrievalContextForQuery(ctx, repoService, activeWorkdir, query) - if retrievalErr != nil { - if isRepositoryContextFatalError(retrievalErr) { - return nil, agentcontext.RepositoryContext{}, retrievalErr - } - s.emitRepositoryContextUnavailable(ctx, state, "retrieval", string(query.Mode), retrievalErr) - } else { - repoContext.Retrieval = retrieval - } + return nil, agentcontext.RepositoryContext{}, nil } - return summarySection, repoContext, nil + summarySection := projectRepositorySummary(inspectResult.Summary) + return summarySection, agentcontext.RepositoryContext{}, nil } // repositoryFacts 返回 runtime 当前使用的 repository 事实服务,并在缺省时回落到默认实现。 @@ -93,13 +46,6 @@ func (s *Service) repositoryFacts() repositoryFactService { return repository.NewService() } -func changedFilesLimitForUserText(includeSnippets bool) int { - if includeSnippets { - return defaultAutoChangedFilesWithDiff - } - return defaultAutoChangedFilesLimit -} - func projectRepositorySummary(summary repository.Summary) *agentcontext.RepositorySummarySection { if !summary.InGitRepo { return nil @@ -113,45 +59,6 @@ func projectRepositorySummary(summary repository.Summary) *agentcontext.Reposito } } -func changedFilesProjectionForUserText(userText string, changed repository.ChangedFilesContext) *agentcontext.RepositoryChangedFilesSection { - explicitChangedFilesIntent := shouldAutoInjectChangedFiles(userText) - if len(changed.Files) == 0 { - return nil - } - if !explicitChangedFilesIntent && (changed.TotalCount <= 0 || changed.TotalCount 
> maxAutoChangedFilesCount) { - return nil - } - return &agentcontext.RepositoryChangedFilesSection{ - Files: append([]repository.ChangedFile(nil), changed.Files...), - Truncated: changed.Truncated, - ReturnedCount: changed.ReturnedCount, - TotalCount: changed.TotalCount, - } -} - -// buildRetrievalContextForQuery 基于已解析出的显式锚点执行单次定向检索并投影为 context 结构。 -func (s *Service) buildRetrievalContextForQuery( - ctx context.Context, - repoService repositoryFactService, - workdir string, - query repository.RetrievalQuery, -) (*agentcontext.RepositoryRetrievalSection, error) { - result, err := repoService.Retrieve(ctx, workdir, query) - if err != nil { - return nil, err - } - if len(result.Hits) == 0 { - return nil, nil - } - - return &agentcontext.RepositoryRetrievalSection{ - Hits: append([]repository.RetrievalHit(nil), result.Hits...), - Truncated: result.Truncated, - Mode: string(query.Mode), - Query: query.Value, - }, nil -} - // emitRepositoryContextUnavailable 记录 repository 事实获取失败但已降级为空上下文的可观测事件。 func (s *Service) emitRepositoryContextUnavailable( ctx context.Context, @@ -199,179 +106,6 @@ func extractTextParts(parts []providertypes.ContentPart) string { return strings.TrimSpace(strings.Join(fragments, "\n")) } -// shouldAutoInjectChangedFiles 判断本轮是否应优先注入 changed-files 摘要。 -func shouldAutoInjectChangedFiles(userText string) bool { - lower := strings.ToLower(strings.TrimSpace(userText)) - if lower == "" { - return false - } - keywords := []string{ - "当前改动", - "这次修改", - "changed files", - "current diff", - "git diff", - "review 我的改动", - "review my changes", - "我的改动", - "本次改动", - "未提交", - } - for _, keyword := range keywords { - if strings.Contains(lower, keyword) { - return true - } - } - return false -} - -// shouldAutoIncludeChangedFileSnippets 仅在小变更集的 review/fix 语义下升级为 snippet 注入。 -func shouldAutoIncludeChangedFileSnippets(userText string) bool { - lower := strings.ToLower(strings.TrimSpace(userText)) - if lower == "" { - return false - } - keywords := []string{ - 
"review", - "diff", - "patch", - "解释改动", - "explain changes", - "fix", - "修复", - } - for _, keyword := range keywords { - if strings.Contains(lower, keyword) { - return true - } - } - return false -} - -// mentionsFixOrReviewIntent 判断问题是否属于更依赖当前工作树状态的 fix/review 类型任务。 -func mentionsFixOrReviewIntent(userText string) bool { - lower := strings.ToLower(strings.TrimSpace(userText)) - if lower == "" { - return false - } - keywords := []string{ - "fix", - "debug", - "review", - "修复", - "排查", - "debugging", - "bug", - } - for _, keyword := range keywords { - if strings.Contains(lower, keyword) { - return true - } - } - return false -} - -// autoRetrievalQueryFromUserText 基于显式锚点抽取本轮至多一组自动 retrieval 请求。 -func autoRetrievalQueryFromUserText(workdir string, userText string) (repository.RetrievalQuery, bool) { - if pathQuery, ok := autoPathRetrievalQuery(workdir, userText); ok { - return pathQuery, true - } - if symbolQuery, ok := autoSymbolRetrievalQuery(userText); ok { - return symbolQuery, true - } - if textQuery, ok := autoTextRetrievalQuery(userText); ok { - return textQuery, true - } - return repository.RetrievalQuery{}, false -} - -// autoPathRetrievalQuery 从文本中提取最明确的路径锚点,并映射为 path 模式检索。 -func autoPathRetrievalQuery(workdir string, userText string) (repository.RetrievalQuery, bool) { - match := pathAnchorPattern.FindString(strings.TrimSpace(userText)) - if strings.TrimSpace(match) == "" { - return repository.RetrievalQuery{}, false - } - candidate := strings.Trim(match, "`\"'") - if !workspacePathAnchorExists(workdir, candidate) { - return repository.RetrievalQuery{}, false - } - return repository.RetrievalQuery{ - Mode: repository.RetrievalModePath, - Value: candidate, - Limit: defaultAutoPathRetrievalLimit, - ContextLines: defaultAutoRetrievalContextLines, - }, true -} - -func workspacePathAnchorExists(workdir string, path string) bool { - if strings.TrimSpace(workdir) == "" || strings.TrimSpace(path) == "" { - return false - } - _, target, err := 
security.ResolveWorkspacePath(workdir, path) - if err != nil { - return false - } - info, err := os.Stat(target) - if err != nil { - return false - } - return !info.IsDir() -} - -// autoSymbolRetrievalQuery 仅在句式明显指向符号定义/实现时抽取 Go-first 符号检索。 -func autoSymbolRetrievalQuery(userText string) (repository.RetrievalQuery, bool) { - lower := strings.ToLower(userText) - if !(strings.Contains(lower, "定义") || - strings.Contains(lower, "实现") || - strings.Contains(lower, "在哪") || - strings.Contains(lower, "where is") || - strings.Contains(lower, "explain") || - strings.Contains(lower, "look at")) { - return repository.RetrievalQuery{}, false - } - - matches := quotedTextPattern.FindAllStringSubmatch(userText, -1) - for _, match := range matches { - for _, group := range match[1:] { - candidate := strings.TrimSpace(group) - if candidate == "" || !symbolAnchorPattern.MatchString(candidate) || candidate != symbolAnchorPattern.FindString(candidate) { - continue - } - return repository.RetrievalQuery{ - Mode: repository.RetrievalModeSymbol, - Value: candidate, - Limit: defaultAutoSymbolRetrievalLimit, - ContextLines: defaultAutoRetrievalContextLines, - }, true - } - } - return repository.RetrievalQuery{}, false -} - -// autoTextRetrievalQuery 只对显式包裹的关键字做一次有限文本检索,避免宽泛问题误触发。 -func autoTextRetrievalQuery(userText string) (repository.RetrievalQuery, bool) { - matches := quotedTextPattern.FindAllStringSubmatch(userText, -1) - for _, match := range matches { - candidate := "" - for _, group := range match[1:] { - if strings.TrimSpace(group) != "" { - candidate = strings.TrimSpace(group) - break - } - } - if candidate == "" || len([]rune(candidate)) < 3 || strings.Contains(candidate, "/") || strings.Contains(candidate, "\\") { - continue - } - return repository.RetrievalQuery{ - Mode: repository.RetrievalModeText, - Value: candidate, - Limit: defaultAutoTextRetrievalLimit, - ContextLines: defaultAutoTextRetrievalContext, - }, true - } - return repository.RetrievalQuery{}, false -} - // 
isRepositoryContextFatalError 只把上下文取消类错误视作主链应立即返回的致命错误。 func isRepositoryContextFatalError(err error) bool { return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) diff --git a/internal/runtime/repository_context_additional_test.go b/internal/runtime/repository_context_additional_test.go index 5e8968a7..d2a3bf37 100644 --- a/internal/runtime/repository_context_additional_test.go +++ b/internal/runtime/repository_context_additional_test.go @@ -3,10 +3,9 @@ package runtime import ( "context" "errors" - "path/filepath" "testing" - "neo-code/internal/context/repository" + "neo-code/internal/repository" providertypes "neo-code/internal/provider/types" agentsession "neo-code/internal/session" ) @@ -50,41 +49,11 @@ func TestBuildRepositoryContextEarlyReturnAndFatalPaths(t *testing.T) { if _, _, err := fatalFromInspect.buildRepositoryContext(context.Background(), &state, state.session.Workdir); !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected fatal inspect error, got %v", err) } - - workdir := t.TempDir() - mustRuntimeWriteFile(t, filepath.Join(workdir, "README.md"), "# readme\n") - fatalFromRetrieval := &Service{ - repositoryService: &stubRepositoryFactService{ - inspectFn: func(ctx context.Context, workdir string, opts repository.InspectOptions) (repository.InspectResult, error) { - return repository.InspectResult{Summary: repository.Summary{InGitRepo: true, Branch: "main"}}, nil - }, - retrieveFn: func(ctx context.Context, workdir string, query repository.RetrievalQuery) (repository.RetrievalResult, error) { - return repository.RetrievalResult{}, context.Canceled - }, - }, - events: make(chan RuntimeEvent, 8), - } - retrievalState := newRepositoryTestState(workdir, "看看 README.md") - _, _, err := fatalFromRetrieval.buildRepositoryContext(context.Background(), &retrievalState, retrievalState.session.Workdir) - if !errors.Is(err, context.Canceled) { - t.Fatalf("expected fatal retrieval error, got %v", err) - } } func 
TestRepositoryContextHelpers(t *testing.T) { t.Parallel() - workdir := t.TempDir() - mustRuntimeWriteFile(t, filepath.Join(workdir, "README.md"), "# readme\n") - mustRuntimeWriteFile(t, filepath.Join(workdir, "internal", "runtime", "run.go"), "package runtime\n") - - if got := changedFilesLimitForUserText(false); got != defaultAutoChangedFilesLimit { - t.Fatalf("changedFilesLimitForUserText(false) = %d", got) - } - if got := changedFilesLimitForUserText(true); got != defaultAutoChangedFilesWithDiff { - t.Fatalf("changedFilesLimitForUserText(true) = %d", got) - } - if projectRepositorySummary(repository.Summary{}) != nil { t.Fatalf("expected nil summary projection for non-git") } @@ -99,74 +68,6 @@ func TestRepositoryContextHelpers(t *testing.T) { t.Fatalf("unexpected summary projection: %+v", summary) } - if changedFilesProjectionForUserText("解释架构", repository.ChangedFilesContext{ - Files: []repository.ChangedFile{{Path: "a.go", Status: repository.StatusModified}}, - ReturnedCount: 1, - TotalCount: maxAutoChangedFilesCount + 1, - }) != nil { - t.Fatalf("expected implicit large changed-files set to be dropped") - } - if projection := changedFilesProjectionForUserText("review 我的改动", repository.ChangedFilesContext{ - Files: []repository.ChangedFile{{Path: "a.go", Status: repository.StatusModified}}, - ReturnedCount: 1, - TotalCount: maxAutoChangedFilesCount + 1, - Truncated: true, - }); projection == nil || !projection.Truncated { - t.Fatalf("expected explicit changed-files projection, got %+v", projection) - } - - if query, ok := autoRetrievalQueryFromUserText(workdir, "解释这个模块"); ok { - t.Fatalf("expected no query, got %+v", query) - } - if query, ok := autoPathRetrievalQuery(workdir, "`internal/runtime/run.go`"); !ok || query.Mode != repository.RetrievalModePath { - t.Fatalf("autoPathRetrievalQuery(subdir) = (%+v, %t)", query, ok) - } - if query, ok := autoPathRetrievalQuery(workdir, "README.md"); !ok || query.Value != "README.md" { - 
t.Fatalf("autoPathRetrievalQuery(root) = (%+v, %t)", query, ok) - } - if _, ok := autoPathRetrievalQuery(workdir, "missing.go"); ok { - t.Fatalf("expected missing root file to not trigger path retrieval") - } - if workspacePathAnchorExists(workdir, "README.md") == false { - t.Fatalf("expected README.md to exist as anchor") - } - if workspacePathAnchorExists(workdir, "missing.go") { - t.Fatalf("expected missing anchor to be rejected") - } - - if _, ok := autoSymbolRetrievalQuery("BuildWidget 在吗"); ok { - t.Fatalf("expected symbol query to require intent words") - } - if _, ok := autoSymbolRetrievalQuery("where is BuildWidget"); ok { - t.Fatalf("expected bare capitalized word to not trigger symbol retrieval") - } - if query, ok := autoSymbolRetrievalQuery("where is `BuildWidget`"); !ok || query.Value != "BuildWidget" { - t.Fatalf("autoSymbolRetrievalQuery() = (%+v, %t)", query, ok) - } - - if _, ok := autoTextRetrievalQuery("find `internal/runtime/run.go`"); ok { - t.Fatalf("expected path-like quoted text to be ignored") - } - if _, ok := autoTextRetrievalQuery("find `go`"); ok { - t.Fatalf("expected short quoted text to be ignored") - } - if query, ok := autoTextRetrievalQuery("find `permission_requested`"); !ok || query.Value != "permission_requested" { - t.Fatalf("autoTextRetrievalQuery() = (%+v, %t)", query, ok) - } - - if query, ok := autoRetrievalQueryFromUserText(workdir, "看看 README.md 的 BuildWidget 和 `permission_requested`"); !ok || query.Mode != repository.RetrievalModePath { - t.Fatalf("expected path query to win priority, got (%+v, %t)", query, ok) - } - - if !shouldAutoInjectChangedFiles("请看 changed files") || shouldAutoInjectChangedFiles("just chat") { - t.Fatalf("shouldAutoInjectChangedFiles() mismatch") - } - if !shouldAutoIncludeChangedFileSnippets("please review diff") || shouldAutoIncludeChangedFileSnippets("just explain") { - t.Fatalf("shouldAutoIncludeChangedFileSnippets() mismatch") - } - if !mentionsFixOrReviewIntent("debug this bug") || 
mentionsFixOrReviewIntent("architecture overview") { - t.Fatalf("mentionsFixOrReviewIntent() mismatch") - } if !isRepositoryContextFatalError(context.Canceled) || !isRepositoryContextFatalError(context.DeadlineExceeded) || isRepositoryContextFatalError(errors.New("x")) { t.Fatalf("isRepositoryContextFatalError() mismatch") } diff --git a/internal/runtime/repository_context_test.go b/internal/runtime/repository_context_test.go index 898d39bb..cf7379c3 100644 --- a/internal/runtime/repository_context_test.go +++ b/internal/runtime/repository_context_test.go @@ -3,24 +3,21 @@ package runtime import ( "context" "errors" - "os" - "path/filepath" "testing" agentcontext "neo-code/internal/context" - "neo-code/internal/context/repository" + "neo-code/internal/repository" providertypes "neo-code/internal/provider/types" agentsession "neo-code/internal/session" "neo-code/internal/tools" ) type stubRepositoryFactService struct { - inspectFn func(ctx context.Context, workdir string, opts repository.InspectOptions) (repository.InspectResult, error) - retrieveFn func(ctx context.Context, workdir string, query repository.RetrievalQuery) (repository.RetrievalResult, error) - inspectCalls int - retrieveCalls int - lastInspectOpts repository.InspectOptions - lastRetrieveQuery repository.RetrievalQuery + inspectFn func(ctx context.Context, workdir string, opts repository.InspectOptions) (repository.InspectResult, error) + retrieveFn func(ctx context.Context, workdir string, query repository.RetrievalQuery) (repository.RetrievalResult, error) + inspectCalls int + retrieveCalls int + lastInspectOpts repository.InspectOptions } func (s *stubRepositoryFactService) Inspect( @@ -42,7 +39,6 @@ func (s *stubRepositoryFactService) Retrieve( query repository.RetrievalQuery, ) (repository.RetrievalResult, error) { s.retrieveCalls++ - s.lastRetrieveQuery = query if s.retrieveFn != nil { return s.retrieveFn(ctx, workdir, query) } @@ -59,29 +55,7 @@ func newRepositoryTestState(workdir string, text 
string) runState { return newRunState("run-repository-context", session) } -func TestBuildRepositoryContextSkipsWithoutAnchors(t *testing.T) { - t.Parallel() - - repoService := &stubRepositoryFactService{} - state := newRepositoryTestState(t.TempDir(), "解释一下 runtime 架构") - service := &Service{repositoryService: repoService, events: make(chan RuntimeEvent, 8)} - - summary, repoContext, err := service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) - if err != nil { - t.Fatalf("buildRepositoryContext() error = %v", err) - } - if summary != nil { - t.Fatalf("expected nil summary for non-git inspect result, got %+v", summary) - } - if repoContext.ChangedFiles != nil || repoContext.Retrieval != nil { - t.Fatalf("expected empty repository context, got %+v", repoContext) - } - if repoService.inspectCalls != 1 || repoService.retrieveCalls != 0 { - t.Fatalf("expected inspect once and no retrieval, got inspect=%d retrieve=%d", repoService.inspectCalls, repoService.retrieveCalls) - } -} - -func TestBuildRepositoryContextUsesInspectForSummaryAndChangedFiles(t *testing.T) { +func TestBuildRepositoryContextReturnsSummaryOnly(t *testing.T) { t.Parallel() repoService := &stubRepositoryFactService{ @@ -94,13 +68,6 @@ func TestBuildRepositoryContextUsesInspectForSummaryAndChangedFiles(t *testing.T Ahead: 2, Behind: 1, }, - ChangedFiles: repository.ChangedFilesContext{ - Files: []repository.ChangedFile{ - {Path: "internal/runtime/run.go", Status: repository.StatusModified, Snippet: "@@ snippet"}, - }, - ReturnedCount: 1, - TotalCount: 1, - }, }, nil }, } @@ -114,215 +81,36 @@ func TestBuildRepositoryContextUsesInspectForSummaryAndChangedFiles(t *testing.T if summary == nil || summary.Branch != "feature/repository" || !summary.Dirty || summary.Ahead != 2 || summary.Behind != 1 { t.Fatalf("unexpected summary projection: %+v", summary) } - if repoContext.ChangedFiles == nil || len(repoContext.ChangedFiles.Files) != 1 { - t.Fatalf("expected changed files 
context, got %+v", repoContext.ChangedFiles) - } - if repoService.inspectCalls != 1 { - t.Fatalf("expected a single inspect call, got %d", repoService.inspectCalls) - } - if !repoService.lastInspectOpts.IncludeChangedFileSnippets { - t.Fatalf("expected snippets to be enabled, got %+v", repoService.lastInspectOpts) - } - if repoService.lastInspectOpts.ChangedFilesLimit != defaultAutoChangedFilesWithDiff { - t.Fatalf("expected changed-files limit %d, got %+v", defaultAutoChangedFilesWithDiff, repoService.lastInspectOpts) - } - if repoService.lastInspectOpts.ChangedFileSnippetFileCountLimit != maxAutoSnippetChangedFilesCount { - t.Fatalf("expected snippet file count limit %d, got %+v", maxAutoSnippetChangedFilesCount, repoService.lastInspectOpts) - } -} - -func TestBuildRepositoryContextSkipsImplicitLargeChangedSet(t *testing.T) { - t.Parallel() - - repoService := &stubRepositoryFactService{ - inspectFn: func(ctx context.Context, workdir string, opts repository.InspectOptions) (repository.InspectResult, error) { - return repository.InspectResult{ - Summary: repository.Summary{InGitRepo: true, Branch: "main", Dirty: true}, - ChangedFiles: repository.ChangedFilesContext{ - Files: []repository.ChangedFile{{Path: "internal/runtime/run.go", Status: repository.StatusModified}}, - ReturnedCount: 1, - TotalCount: maxAutoChangedFilesCount + 1, - }, - }, nil - }, - } - state := newRepositoryTestState(t.TempDir(), "fix 这个 bug") - service := &Service{repositoryService: repoService, events: make(chan RuntimeEvent, 8)} - - _, repoContext, err := service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) - if err != nil { - t.Fatalf("buildRepositoryContext() error = %v", err) - } - if repoContext.ChangedFiles != nil { - t.Fatalf("expected implicit large changed set to be skipped, got %+v", repoContext.ChangedFiles) + if repoContext.ChangedFiles != nil || repoContext.Retrieval != nil { + t.Fatalf("expected empty repository context, got %+v", repoContext) } if 
repoService.inspectCalls != 1 { t.Fatalf("expected a single inspect call, got %d", repoService.inspectCalls) } -} - -func TestBuildRepositoryContextInjectsExplicitLargeChangedSet(t *testing.T) { - t.Parallel() - - repoService := &stubRepositoryFactService{ - inspectFn: func(ctx context.Context, workdir string, opts repository.InspectOptions) (repository.InspectResult, error) { - return repository.InspectResult{ - Summary: repository.Summary{InGitRepo: true, Branch: "main", Dirty: true}, - ChangedFiles: repository.ChangedFilesContext{ - Files: []repository.ChangedFile{{Path: "internal/runtime/run.go", Status: repository.StatusModified}}, - ReturnedCount: 1, - TotalCount: maxAutoChangedFilesCount + 5, - Truncated: true, - }, - }, nil - }, - } - state := newRepositoryTestState(t.TempDir(), "review 我的改动") - service := &Service{repositoryService: repoService, events: make(chan RuntimeEvent, 8)} - - _, repoContext, err := service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) - if err != nil { - t.Fatalf("buildRepositoryContext() error = %v", err) - } - if repoContext.ChangedFiles == nil || repoContext.ChangedFiles.TotalCount <= maxAutoChangedFilesCount { - t.Fatalf("expected explicit changed-files intent to keep truncated large set, got %+v", repoContext.ChangedFiles) - } -} - -func TestBuildRepositoryContextUsesPathRetrievalWithHighestPriority(t *testing.T) { - t.Parallel() - - workdir := t.TempDir() - mustRuntimeWriteFile(t, filepath.Join(workdir, "internal", "runtime", "run.go"), "package runtime\n") - repoService := &stubRepositoryFactService{ - retrieveFn: func(ctx context.Context, workdir string, query repository.RetrievalQuery) (repository.RetrievalResult, error) { - return repository.RetrievalResult{Hits: []repository.RetrievalHit{{ - Path: "internal/runtime/run.go", - Kind: string(query.Mode), - SymbolOrQuery: query.Value, - Snippet: "func ...", - LineHint: 1, - }}, Truncated: true}, nil - }, - } - state := 
newRepositoryTestState(workdir, "看看 internal/runtime/run.go 里 ExecuteSystemTool 是怎么处理的") - service := &Service{repositoryService: repoService, events: make(chan RuntimeEvent, 8)} - - _, repoContext, err := service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) - if err != nil { - t.Fatalf("buildRepositoryContext() error = %v", err) - } - if repoContext.Retrieval == nil { - t.Fatalf("expected retrieval context") - } - if repoService.lastRetrieveQuery.Mode != repository.RetrievalModePath { - t.Fatalf("expected path retrieval, got %+v", repoService.lastRetrieveQuery) - } - if !repoContext.Retrieval.Truncated { - t.Fatalf("expected retrieval truncation to propagate") + if repoService.lastInspectOpts.ChangedFilesLimit != 0 { + t.Fatalf("expected no changed-files limit, got %+v", repoService.lastInspectOpts) } } -func TestBuildRepositoryContextSupportsRootFilePathAnchor(t *testing.T) { +func TestBuildRepositoryContextSkipsWithoutGitRepo(t *testing.T) { t.Parallel() - workdir := t.TempDir() - mustRuntimeWriteFile(t, filepath.Join(workdir, "README.md"), "# readme\n") - repoService := &stubRepositoryFactService{ - retrieveFn: func(ctx context.Context, workdir string, query repository.RetrievalQuery) (repository.RetrievalResult, error) { - return repository.RetrievalResult{Hits: []repository.RetrievalHit{{Path: "README.md", Kind: string(query.Mode), LineHint: 1}}}, nil - }, - } - state := newRepositoryTestState(workdir, "解释一下 README.md") + repoService := &stubRepositoryFactService{} + state := newRepositoryTestState(t.TempDir(), "解释一下 runtime 架构") service := &Service{repositoryService: repoService, events: make(chan RuntimeEvent, 8)} - _, repoContext, err := service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) + summary, repoContext, err := service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) if err != nil { t.Fatalf("buildRepositoryContext() error = %v", err) } - if 
repoContext.Retrieval == nil || repoService.lastRetrieveQuery.Mode != repository.RetrievalModePath || repoService.lastRetrieveQuery.Value != "README.md" { - t.Fatalf("expected root path retrieval, got context=%+v query=%+v", repoContext.Retrieval, repoService.lastRetrieveQuery) - } -} - -func TestBuildRepositoryContextUsesSymbolAndTextRetrievalAnchors(t *testing.T) { - t.Parallel() - - t.Run("symbol anchor", func(t *testing.T) { - repoService := &stubRepositoryFactService{ - retrieveFn: func(ctx context.Context, workdir string, query repository.RetrievalQuery) (repository.RetrievalResult, error) { - return repository.RetrievalResult{Hits: []repository.RetrievalHit{{Path: "internal/runtime/system_tool.go", Kind: string(query.Mode), LineHint: 8}}}, nil - }, - } - state := newRepositoryTestState(t.TempDir(), "where is `ExecuteSystemTool`") - service := &Service{repositoryService: repoService, events: make(chan RuntimeEvent, 8)} - - _, repoContext, err := service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) - if err != nil { - t.Fatalf("buildRepositoryContext() error = %v", err) - } - if repoContext.Retrieval == nil || repoService.lastRetrieveQuery.Mode != repository.RetrievalModeSymbol { - t.Fatalf("expected symbol retrieval, got context=%+v query=%+v", repoContext.Retrieval, repoService.lastRetrieveQuery) - } - }) - - t.Run("quoted text anchor", func(t *testing.T) { - repoService := &stubRepositoryFactService{ - retrieveFn: func(ctx context.Context, workdir string, query repository.RetrievalQuery) (repository.RetrievalResult, error) { - return repository.RetrievalResult{Hits: []repository.RetrievalHit{{Path: "internal/runtime/events.go", Kind: string(query.Mode), LineHint: 14}}}, nil - }, - } - state := newRepositoryTestState(t.TempDir(), "找 `permission_requested` 在哪里处理") - service := &Service{repositoryService: repoService, events: make(chan RuntimeEvent, 8)} - - _, repoContext, err := 
service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) - if err != nil { - t.Fatalf("buildRepositoryContext() error = %v", err) - } - if repoContext.Retrieval == nil || repoService.lastRetrieveQuery.Mode != repository.RetrievalModeText { - t.Fatalf("expected text retrieval, got context=%+v query=%+v", repoContext.Retrieval, repoService.lastRetrieveQuery) - } - }) -} - -func TestPrepareTurnBudgetSnapshotPassesRepositoryContextToBuilder(t *testing.T) { - t.Parallel() - - manager := newRuntimeConfigManager(t) - builder := &stubContextBuilder{} - repoService := &stubRepositoryFactService{ - inspectFn: func(ctx context.Context, workdir string, opts repository.InspectOptions) (repository.InspectResult, error) { - return repository.InspectResult{ - Summary: repository.Summary{InGitRepo: true, Branch: "main", Dirty: true}, - ChangedFiles: repository.ChangedFilesContext{ - Files: []repository.ChangedFile{{Path: "internal/runtime/run.go", Status: repository.StatusModified}}, - ReturnedCount: 1, - TotalCount: 1, - }, - }, nil - }, - } - - service := &Service{ - configManager: manager, - contextBuilder: builder, - toolManager: tools.NewRegistry(), - repositoryService: repoService, - providerFactory: &scriptedProviderFactory{provider: &scriptedProvider{}}, - events: make(chan RuntimeEvent, 8), - } - state := newRepositoryTestState(t.TempDir(), "请 review 当前改动") - - if _, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); err != nil { - t.Fatalf("prepareTurnBudgetSnapshot() error = %v", err) - } else if rebuilt { - t.Fatalf("expected rebuilt=false") + if summary != nil { + t.Fatalf("expected nil summary for non-git inspect result, got %+v", summary) } - if builder.lastInput.Repository.ChangedFiles == nil { - t.Fatalf("expected builder to receive changed files context") + if repoContext.ChangedFiles != nil || repoContext.Retrieval != nil { + t.Fatalf("expected empty repository context, got %+v", repoContext) } - if 
builder.lastInput.RepositorySummary == nil || builder.lastInput.RepositorySummary.Branch != "main" { - t.Fatalf("expected builder to receive repository summary, got %+v", builder.lastInput.RepositorySummary) + if repoService.inspectCalls != 1 || repoService.retrieveCalls != 0 { + t.Fatalf("expected inspect once and no retrieval, got inspect=%d retrieve=%d", repoService.inspectCalls, repoService.retrieveCalls) } } @@ -366,62 +154,38 @@ func TestBuildRepositoryContextEmitsUnavailableEventForSummaryFailure(t *testing t.Fatalf("expected repository unavailable event payload") } -func TestBuildRepositoryContextEmitsUnavailableEventForRetrievalFailure(t *testing.T) { +func TestPrepareTurnBudgetSnapshotPassesRepositorySummaryToBuilder(t *testing.T) { t.Parallel() - workdir := t.TempDir() - mustRuntimeWriteFile(t, filepath.Join(workdir, "README.md"), "# readme\n") + manager := newRuntimeConfigManager(t) + builder := &stubContextBuilder{} repoService := &stubRepositoryFactService{ inspectFn: func(ctx context.Context, workdir string, opts repository.InspectOptions) (repository.InspectResult, error) { return repository.InspectResult{ - Summary: repository.Summary{InGitRepo: true, Branch: "main"}, + Summary: repository.Summary{InGitRepo: true, Branch: "main", Dirty: true}, }, nil }, - retrieveFn: func(ctx context.Context, workdir string, query repository.RetrievalQuery) (repository.RetrievalResult, error) { - return repository.RetrievalResult{}, errors.New("read failed") - }, } + service := &Service{ + configManager: manager, + contextBuilder: builder, + toolManager: tools.NewRegistry(), repositoryService: repoService, + providerFactory: &scriptedProviderFactory{provider: &scriptedProvider{}}, events: make(chan RuntimeEvent, 8), } - state := newRepositoryTestState(workdir, "看看 README.md") - - summary, repoContext, err := service.buildRepositoryContext(context.Background(), &state, state.session.Workdir) - if err != nil { - t.Fatalf("buildRepositoryContext() error = %v", err) - 
} - if summary == nil || summary.Branch != "main" { - t.Fatalf("expected summary to survive retrieval failure, got %+v", summary) - } - if repoContext != (agentcontext.RepositoryContext{}) { - t.Fatalf("expected empty repository context on retrieval failure, got %+v", repoContext) - } + state := newRepositoryTestState(t.TempDir(), "请 review 当前改动") - events := collectRuntimeEvents(service.Events()) - assertEventContains(t, events, EventRepositoryContextUnavailable) - for _, event := range events { - if event.Type != EventRepositoryContextUnavailable { - continue - } - payload, ok := event.Payload.(RepositoryContextUnavailablePayload) - if !ok { - t.Fatalf("payload type = %T, want RepositoryContextUnavailablePayload", event.Payload) - } - if payload.Stage != "retrieval" || payload.Mode != "path" || payload.Reason == "" { - t.Fatalf("unexpected payload: %+v", payload) - } - return + if _, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); err != nil { + t.Fatalf("prepareTurnBudgetSnapshot() error = %v", err) + } else if rebuilt { + t.Fatalf("expected rebuilt=false") } - t.Fatalf("expected repository unavailable event payload") -} - -func mustRuntimeWriteFile(t *testing.T, path string, content string) { - t.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - t.Fatalf("MkdirAll() error = %v", err) + if builder.lastInput.Repository.ChangedFiles != nil { + t.Fatalf("expected builder to receive no changed files context") } - if err := os.WriteFile(path, []byte(content), 0o644); err != nil { - t.Fatalf("WriteFile() error = %v", err) + if builder.lastInput.RepositorySummary == nil || builder.lastInput.RepositorySummary.Branch != "main" { + t.Fatalf("expected builder to receive repository summary, got %+v", builder.lastInput.RepositorySummary) } } diff --git a/internal/runtime/run.go b/internal/runtime/run.go index fed80f1f..4f1a8cee 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -101,6 +101,23 @@ 
func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } s.updateResumeCheckpoint(runCtx, statePtr, "stopped", completion) } + if statePtr != nil && s.perEditStore != nil && statePtr.baselineCheckpointID != "" && statePtr.lastEndOfTurnCheckpointID != "" { + diffStr, _ := s.perEditStore.Diff(context.Background(), statePtr.baselineCheckpointID, statePtr.lastEndOfTurnCheckpointID) + files, _ := s.perEditStore.ChangedFiles(context.Background(), statePtr.baselineCheckpointID, statePtr.lastEndOfTurnCheckpointID) + var changedFiles []FileDiffEntry + for _, f := range files { + changedFiles = append(changedFiles, FileDiffEntry{ + Path: f.Path, + Kind: string(f.Kind), + }) + } + s.emitRunScopedOptional(EventRunDiffSummary, statePtr, RunDiffSummaryPayload{ + FromCheckpointID: statePtr.baselineCheckpointID, + ToCheckpointID: statePtr.lastEndOfTurnCheckpointID, + Diff: diffStr, + ChangedFiles: changedFiles, + }) + } s.emitRunTermination(runCtx, input, statePtr, err) }() ctx = runCtx @@ -187,6 +204,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.updateResumeCheckpoint(ctx, &state, "plan", "") maxTurns := resolveRuntimeMaxTurns(initialCfg.Runtime) + state.baselineCheckpointID = s.findPreviousEndOfTurnCheckpoint(ctx, sessionID, input.RunID) for turn := 0; ; turn++ { if turn >= maxTurns { state.maxTurnsReached = true diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 0cf760dc..0f2b5dc3 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -13,7 +13,7 @@ import ( "neo-code/internal/config" agentcontext "neo-code/internal/context" contextcompact "neo-code/internal/context/compact" - "neo-code/internal/context/repository" + "neo-code/internal/repository" "neo-code/internal/provider" "neo-code/internal/provider/builtin" providertypes "neo-code/internal/provider/types" diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index d52ae0cf..4dd5b06e 100644 --- 
a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -15,7 +15,7 @@ import ( "neo-code/internal/config" agentcontext "neo-code/internal/context" contextcompact "neo-code/internal/context/compact" - "neo-code/internal/context/repository" + "neo-code/internal/repository" "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" approvalflow "neo-code/internal/runtime/approval" diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 6d54c81f..a27fcb52 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -50,6 +50,8 @@ type runState struct { hasUnknownUsage bool completion controlplane.CompletionState progress controlplane.ProgressState + lastEndOfTurnCheckpointID string + baselineCheckpointID string hookAnnotations []string hookNotifications []queuedHookNotification hookNotificationSeen map[string]time.Time diff --git a/internal/runtime/tool_diff_helpers_test.go b/internal/runtime/tool_diff_helpers_test.go index d015b674..cec450e6 100644 --- a/internal/runtime/tool_diff_helpers_test.go +++ b/internal/runtime/tool_diff_helpers_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "neo-code/internal/checkpoint" + "neo-code/internal/repository" providertypes "neo-code/internal/provider/types" agentsession "neo-code/internal/session" "neo-code/internal/tools" @@ -232,7 +232,7 @@ func TestToolExecutionHelperFunctions(t *testing.T) { }) t.Run("collectUncoveredBashPaths removes covered and duplicate entries", func(t *testing.T) { - diff := checkpoint.FingerprintDiff{ + diff := repository.FingerprintDiff{ Added: []string{"new.txt", "new.txt"}, Modified: []string{"tracked.txt", "covered.txt"}, } @@ -246,7 +246,7 @@ func TestToolExecutionHelperFunctions(t *testing.T) { }) t.Run("collectBashWriteFactPaths includes only added and modified paths", func(t *testing.T) { - got := collectBashWriteFactPaths(checkpoint.FingerprintDiff{ + got := collectBashWriteFactPaths(repository.FingerprintDiff{ Added: 
[]string{"b.txt", "a.txt"}, Modified: []string{"a.txt", "c.txt"}, Deleted: []string{"old.txt"}, @@ -275,7 +275,7 @@ func TestEmitHelpersPublishExpectedEvents(t *testing.T) { state, providertypes.ToolCall{ID: "tool-1"}, "touch x", - checkpoint.FingerprintDiff{ + repository.FingerprintDiff{ Added: []string{"new.txt"}, Modified: []string{"edit.txt"}, Deleted: []string{"old.txt"}, @@ -301,7 +301,7 @@ func TestEmitHelpersPublishExpectedEvents(t *testing.T) { state, providertypes.ToolCall{ID: "tool-2"}, "touch noop", - checkpoint.FingerprintDiff{}, + repository.FingerprintDiff{}, nil, nil, ) diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index bbed3279..7374129a 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -10,6 +10,7 @@ import ( "sync" "neo-code/internal/checkpoint" + "neo-code/internal/repository" providertypes "neo-code/internal/provider/types" runtimefacts "neo-code/internal/runtime/facts" runtimehooks "neo-code/internal/runtime/hooks" @@ -156,7 +157,7 @@ func (s *Service) executeOneToolCall( isBash := strings.EqualFold(strings.TrimSpace(call.Name), tools.ToolNameBash) var preSnaps map[string]fileSnapshot - var preFingerprint checkpoint.WorkdirFingerprint + var preFingerprint repository.WorkdirFingerprint var bashCapturedPaths []string var bashCommand string var bashChangedPaths []string @@ -196,7 +197,7 @@ func (s *Service) executeOneToolCall( if len(bashCapturedPaths) > 0 { _, _ = s.perEditStore.CaptureBatch(bashCapturedPaths) } - if fp, _, err := checkpoint.ScanWorkdir(ctx, snapshot.Workdir, checkpoint.DefaultFingerprintOptions()); err == nil { + if fp, _, err := repository.ScanWorkdir(ctx, snapshot.Workdir, repository.DefaultFingerprintOptions()); err == nil { preFingerprint = fp } } @@ -266,8 +267,8 @@ func (s *Service) executeOneToolCall( } if isBash && preFingerprint != nil && execErr == nil && !result.IsError { - if afterFP, _, err := checkpoint.ScanWorkdir(ctx, snapshot.Workdir, 
checkpoint.DefaultFingerprintOptions()); err == nil { - fpDiff := checkpoint.DiffFingerprints(preFingerprint, afterFP) + if afterFP, _, err := repository.ScanWorkdir(ctx, snapshot.Workdir, repository.DefaultFingerprintOptions()); err == nil { + fpDiff := repository.DiffFingerprints(preFingerprint, afterFP) if len(fpDiff.Added) > 0 || len(fpDiff.Modified) > 0 || len(fpDiff.Deleted) > 0 { bashChangedPaths = collectBashWriteFactPaths(fpDiff) covered := make(map[string]struct{}, len(bashCapturedPaths)) @@ -598,7 +599,7 @@ func bashCommandFromCall(call providertypes.ToolCall) string { } // collectBashWriteFactPaths 从 bash fingerprint diff 中提取可验证的新增/修改路径,删除路径不作为写后验收目标。 -func collectBashWriteFactPaths(fpDiff checkpoint.FingerprintDiff) []string { +func collectBashWriteFactPaths(fpDiff repository.FingerprintDiff) []string { seen := make(map[string]struct{}) out := make([]string, 0, len(fpDiff.Modified)+len(fpDiff.Added)) add := func(path string) { @@ -626,7 +627,7 @@ func collectBashWriteFactPaths(fpDiff checkpoint.FingerprintDiff) []string { // collectUncoveredBashPaths 把 fingerprint 检测到的变更路径与启发式预捕获集合做差, // 输出 EventBashSideEffect.UncoveredPaths 用于可观测性提醒。 -func collectUncoveredBashPaths(workdir string, fpDiff checkpoint.FingerprintDiff, covered map[string]struct{}) []string { +func collectUncoveredBashPaths(workdir string, fpDiff repository.FingerprintDiff, covered map[string]struct{}) []string { if len(fpDiff.Added) == 0 && len(fpDiff.Modified) == 0 { return nil } @@ -673,7 +674,7 @@ func (s *Service) emitBashSideEffectEvent( state *runState, call providertypes.ToolCall, command string, - fpDiff checkpoint.FingerprintDiff, + fpDiff repository.FingerprintDiff, preCaptured []string, uncovered []string, ) { diff --git a/internal/tools/codebase/common.go b/internal/tools/codebase/common.go new file mode 100644 index 00000000..0755387e --- /dev/null +++ b/internal/tools/codebase/common.go @@ -0,0 +1,30 @@ +package codebase + +func itoa(i int) string { + if i == 0 { + return 
"0" + } + negative := i < 0 + if negative { + i = -i + } + var buf [20]byte + bp := len(buf) + for i > 0 { + bp-- + buf[bp] = byte('0' + i%10) + i /= 10 + } + if negative { + bp-- + buf[bp] = '-' + } + return string(buf[bp:]) +} + +func boolToString(v bool) string { + if v { + return "true" + } + return "false" +} diff --git a/internal/tools/codebase/common_test.go b/internal/tools/codebase/common_test.go new file mode 100644 index 00000000..15203a8d --- /dev/null +++ b/internal/tools/codebase/common_test.go @@ -0,0 +1,45 @@ +package codebase + +import ( + "os" + "path/filepath" + "testing" + + "neo-code/internal/tools" +) + +func TestCodebaseCommonHelpers(t *testing.T) { + t.Parallel() + + root := t.TempDir() + child := filepath.Join(root, "subdir") + if err := os.Mkdir(child, 0o755); err != nil { + t.Fatalf("Mkdir() error = %v", err) + } + canonicalRoot, err := filepath.EvalSymlinks(root) + if err != nil { + t.Fatalf("EvalSymlinks(root) error = %v", err) + } + + if got, err := tools.ResolveEffectiveRoot(root, " "); err != nil || got != canonicalRoot { + t.Fatalf("effectiveRoot(default) = %q", got) + } + if got, err := tools.ResolveEffectiveRoot(root, "subdir"); err != nil || got != child { + t.Fatalf("effectiveRoot(custom) = %q", got) + } + if got := itoa(0); got != "0" { + t.Fatalf("itoa(0) = %q", got) + } + if got := itoa(-9); got != "-9" { + t.Fatalf("itoa(-9) = %q", got) + } + if got := boolToString(true); got != "true" { + t.Fatalf("boolToString(true) = %q", got) + } + if got := boolToString(false); got != "false" { + t.Fatalf("boolToString(false) = %q", got) + } + if _, err := tools.ResolveEffectiveRoot(root, "../escape"); err == nil { + t.Fatal("ResolveEffectiveRoot should reject escaping workdir") + } +} diff --git a/internal/tools/codebase/read.go b/internal/tools/codebase/read.go new file mode 100644 index 00000000..f76476e1 --- /dev/null +++ b/internal/tools/codebase/read.go @@ -0,0 +1,110 @@ +package codebase + +import ( + "context" + "encoding/json" 
+ "strings" + + "neo-code/internal/repository" + "neo-code/internal/tools" +) + +// ReadTool implements the codebase_read tool. +type ReadTool struct { + root string + svc *repository.Service +} + +// NewRead creates a new codebase_read tool. +func NewRead(svc *repository.Service, root string) *ReadTool { + return &ReadTool{root: root, svc: svc} +} + +func (t *ReadTool) Name() string { + return tools.ToolNameCodebaseRead +} + +func (t *ReadTool) Description() string { + return "Read the content of a file within the workspace. Use this when you need to see implementation details after locating a file via search tools." +} + +func (t *ReadTool) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Relative path to the file within the workspace.", + }, + "workdir": map[string]any{ + "type": "string", + "description": "Optional working directory relative to the workspace root.", + }, + "max_bytes": map[string]any{ + "type": "integer", + "description": "Maximum bytes to read (default 256KB).", + }, + }, + "required": []string{"path"}, + } +} + +func (t *ReadTool) MicroCompactPolicy() tools.MicroCompactPolicy { + return tools.MicroCompactPolicyPreserveHistory +} + +func (t *ReadTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { + var in struct { + Path string `json:"path"` + Workdir string `json:"workdir,omitempty"` + MaxBytes int `json:"max_bytes,omitempty"` + } + if err := json.Unmarshal(call.Arguments, &in); err != nil { + return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err + } + if strings.TrimSpace(in.Path) == "" { + err := &json.UnmarshalTypeError{} + return tools.NewErrorResult(t.Name(), "missing required argument: path", "", nil), err + } + + root, err := tools.ResolveEffectiveRoot(t.root, in.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", 
err.Error(), nil), err + } + result, err := t.svc.Read(ctx, root, in.Path, repository.ReadOptions{MaxBytes: in.MaxBytes}) + if err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + content := formatReadResult(result) + return tools.ToolResult{ + Name: t.Name(), + Content: content, + Metadata: map[string]any{ + "path": result.Path, + "truncated": result.Truncated, + "is_binary": result.IsBinary, + "size": result.Size, + }, + }, nil +} + +func formatReadResult(r repository.ReadResult) string { + if r.Path == "" && r.Content == "" { + return "file not found or access denied" + } + var b strings.Builder + b.WriteString("path: ") + b.WriteString(r.Path) + b.WriteString("\nis_binary: ") + b.WriteString(boolToString(r.IsBinary)) + b.WriteString("\ntruncated: ") + b.WriteString(boolToString(r.Truncated)) + b.WriteString("\nsize: ") + b.WriteString(itoa(int(r.Size))) + if !r.IsBinary && r.Content != "" { + b.WriteString("\n\n") + b.WriteString(r.Content) + } + return b.String() +} diff --git a/internal/tools/codebase/read_test.go b/internal/tools/codebase/read_test.go new file mode 100644 index 00000000..ba29f03b --- /dev/null +++ b/internal/tools/codebase/read_test.go @@ -0,0 +1,155 @@ +package codebase + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "neo-code/internal/repository" + "neo-code/internal/tools" +) + +func TestReadToolMetadata(t *testing.T) { + t.Parallel() + + tool := NewRead(repository.NewService(), "/workspace") + if tool.Name() != "codebase_read" { + t.Fatalf("Name() = %q, want %q", tool.Name(), "codebase_read") + } + if tool.Description() == "" { + t.Fatalf("Description() should not be empty") + } + schema := tool.Schema() + if schema == nil { + t.Fatalf("Schema() should not be nil") + } + props, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatalf("Schema properties should be a map, got %T", schema["properties"]) + } + if _, hasPath := props["path"]; 
!hasPath { + t.Fatalf("Schema should have path property") + } + if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { + t.Fatalf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) + } +} + +func TestReadToolInvalidJSON(t *testing.T) { + t.Parallel() + + tool := NewRead(repository.NewService(), "/workspace") + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: []byte(`{invalid`), + }) + if err == nil { + t.Fatalf("expected error for invalid JSON, got result: %+v", result) + } + if !result.IsError { + t.Fatalf("expected IsError result") + } +} + +func TestReadToolMissingPath(t *testing.T) { + t.Parallel() + + tool := NewRead(repository.NewService(), "/workspace") + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{}), + }) + if err == nil { + t.Fatalf("expected error for missing path") + } + if !result.IsError { + t.Fatalf("expected IsError result") + } +} + +func TestReadToolFileNotFound(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + tool := NewRead(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"path": "nonexistent.go"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result.Content, "file not found") { + t.Fatalf("expected 'file not found' message, got %q", result.Content) + } +} + +func TestReadToolReadsFileContent(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + content := "package main\nfunc main() {}\n" + if err := os.WriteFile(filepath.Join(workspace, "main.go"), []byte(content), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + tool := NewRead(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + 
Arguments: mustArgs(t, map[string]any{"path": "main.go"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result.Content, "package main") { + t.Fatalf("expected file content, got %q", result.Content) + } + metaPath, _ := result.Metadata["path"].(string) + if metaPath != "main.go" { + t.Fatalf("expected metadata path 'main.go', got %v", result.Metadata["path"]) + } +} + +func TestReadToolPathTraversalRejected(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + tool := NewRead(repository.NewService(), workspace) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"path": "../../etc/passwd"}), + }) + if err == nil { + t.Fatalf("expected error for path traversal") + } + if !strings.Contains(err.Error(), "escapes workspace root") { + t.Fatalf("expected 'escapes workspace root' error, got %v", err) + } +} + +func TestReadToolFileInSubdirectory(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + subdir := filepath.Join(workspace, "sub") + if err := os.Mkdir(subdir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(subdir, "nested.go"), []byte("nested content"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + tool := NewRead(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"path": "sub/nested.go"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result.Content, "nested content") { + t.Fatalf("expected file content, got %q", result.Content) + } +} diff --git a/internal/tools/codebase/searchsymbol.go b/internal/tools/codebase/searchsymbol.go new file mode 100644 index 00000000..ccf22101 --- /dev/null +++ b/internal/tools/codebase/searchsymbol.go @@ -0,0 +1,121 @@ +package codebase + +import ( + "context" + 
"encoding/json" + "strings" + + "neo-code/internal/repository" + "neo-code/internal/tools" +) + +// SearchSymbolTool implements the codebase_search_symbol tool. +type SearchSymbolTool struct { + root string + svc *repository.Service +} + +// NewSearchSymbol creates a new codebase_search_symbol tool. +func NewSearchSymbol(svc *repository.Service, root string) *SearchSymbolTool { + return &SearchSymbolTool{root: root, svc: svc} +} + +func (t *SearchSymbolTool) Name() string { + return tools.ToolNameCodebaseSearchSymbol +} + +func (t *SearchSymbolTool) Description() string { + return "Search for symbol definitions across the workspace. Returns file paths, line hints, kind (function/type/method/etc.), and signature. Does NOT return the function body; use codebase_read to view implementation." +} + +func (t *SearchSymbolTool) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "symbol": map[string]any{ + "type": "string", + "description": "Symbol name to search for.", + }, + "scope_dir": map[string]any{ + "type": "string", + "description": "Optional subdirectory to limit the search scope.", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of hits to return (default 20, max 50).", + }, + "workdir": map[string]any{ + "type": "string", + "description": "Optional working directory relative to the workspace root.", + }, + }, + "required": []string{"symbol"}, + } +} + +func (t *SearchSymbolTool) MicroCompactPolicy() tools.MicroCompactPolicy { + return tools.MicroCompactPolicyCompact +} + +func (t *SearchSymbolTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { + var in struct { + Symbol string `json:"symbol"` + ScopeDir string `json:"scope_dir,omitempty"` + Limit int `json:"limit,omitempty"` + Workdir string `json:"workdir,omitempty"` + } + if err := json.Unmarshal(call.Arguments, &in); err != nil { + return tools.NewErrorResult(t.Name(), "invalid 
arguments", err.Error(), nil), err + } + if strings.TrimSpace(in.Symbol) == "" { + return tools.NewErrorResult(t.Name(), "missing required argument: symbol", "", nil), nil + } + + root, err := tools.ResolveEffectiveRoot(t.root, in.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } + opts := repository.SearchOptions{ + ScopeDir: in.ScopeDir, + Limit: in.Limit, + } + result, err := t.svc.SearchSymbol(ctx, root, in.Symbol, opts) + if err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + content := formatSymbolSearchResult(result) + return tools.ToolResult{ + Name: t.Name(), + Content: content, + Metadata: map[string]any{ + "returned_count": len(result.Hits), + "total_count": result.TotalCount, + "truncated": result.Truncated, + }, + }, nil +} + +func formatSymbolSearchResult(r repository.SymbolSearchResult) string { + var b strings.Builder + b.WriteString("returned_count: ") + b.WriteString(itoa(len(r.Hits))) + b.WriteString("\ntotal_count: ") + b.WriteString(itoa(r.TotalCount)) + b.WriteString("\ntruncated: ") + b.WriteString(boolToString(r.Truncated)) + if len(r.Hits) > 0 { + b.WriteString("\n") + } + for _, h := range r.Hits { + b.WriteString("\n- path: ") + b.WriteString(h.Path) + b.WriteString("\n line_hint: ") + b.WriteString(itoa(h.LineHint)) + b.WriteString("\n kind: ") + b.WriteString(h.Kind) + b.WriteString("\n signature: ") + b.WriteString(h.Signature) + } + return b.String() +} diff --git a/internal/tools/codebase/searchsymbol_test.go b/internal/tools/codebase/searchsymbol_test.go new file mode 100644 index 00000000..33f792be --- /dev/null +++ b/internal/tools/codebase/searchsymbol_test.go @@ -0,0 +1,163 @@ +package codebase + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "neo-code/internal/repository" + "neo-code/internal/tools" +) + +func TestSearchSymbolToolMetadata(t *testing.T) { + t.Parallel() + + 
tool := NewSearchSymbol(repository.NewService(), "/workspace") + if tool.Name() != "codebase_search_symbol" { + t.Fatalf("Name() = %q, want %q", tool.Name(), "codebase_search_symbol") + } + if tool.Description() == "" { + t.Fatalf("Description() should not be empty") + } + schema := tool.Schema() + if schema == nil { + t.Fatalf("Schema() should not be nil") + } + props, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatalf("Schema properties should be a map") + } + if _, hasSymbol := props["symbol"]; !hasSymbol { + t.Fatalf("Schema should have symbol property") + } + if tool.MicroCompactPolicy() != tools.MicroCompactPolicyCompact { + t.Fatalf("MicroCompactPolicy() = %v, want Compact", tool.MicroCompactPolicy()) + } +} + +func TestSearchSymbolToolInvalidJSON(t *testing.T) { + t.Parallel() + + tool := NewSearchSymbol(repository.NewService(), "/workspace") + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: []byte(`{invalid`), + }) + if err == nil { + t.Fatalf("expected error for invalid JSON, got result: %+v", result) + } + if !result.IsError { + t.Fatalf("expected IsError result") + } +} + +func TestSearchSymbolToolMissingSymbol(t *testing.T) { + t.Parallel() + + tool := NewSearchSymbol(repository.NewService(), "/workspace") + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{}), + }) + if err != nil { + t.Fatalf("expected no error for missing symbol, got %v", err) + } + if !result.IsError { + t.Fatalf("expected IsError result") + } + if !strings.Contains(result.Content, "missing required argument") { + t.Fatalf("expected missing argument message, got %q", result.Content) + } +} + +func TestSearchSymbolToolFindsGoSymbol(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + src := `package main + +func Hello(name string) string { + return "hi " + name +} + +type MyStruct struct { + Field int +} +` + if err := 
os.WriteFile(filepath.Join(workspace, "main.go"), []byte(src), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + tool := NewSearchSymbol(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"symbol": "Hello"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result.Content, "returned_count: 1") { + t.Fatalf("expected 1 hit, got %q", result.Content) + } + if !strings.Contains(result.Content, "Hello") { + t.Fatalf("expected 'Hello' in result, got %q", result.Content) + } + if !strings.Contains(result.Content, "function") { + t.Fatalf("expected 'function' kind, got %q", result.Content) + } +} + +func TestSearchSymbolToolNoResults(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + if err := os.WriteFile(filepath.Join(workspace, "main.go"), []byte("package main\nfunc foo() {}\n"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + tool := NewSearchSymbol(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"symbol": "NonExistent"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result.Content, "returned_count: 0") { + t.Fatalf("expected 0 hits, got %q", result.Content) + } +} + +func TestSearchSymbolToolReturnsSignatureOnly(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + src := `package main + +func longFunction(a string, b int, c float64) string { + // This is a very long function body that should not be included + return "result" +} +` + if err := os.WriteFile(filepath.Join(workspace, "main.go"), []byte(src), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + tool := NewSearchSymbol(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + 
Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"symbol": "longFunction"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should contain the signature but NOT the body + if !strings.Contains(result.Content, "func longFunction") { + t.Fatalf("expected signature, got %q", result.Content) + } + if strings.Contains(result.Content, "return \"result\"") { + t.Fatalf("expected signature only, but got function body in %q", result.Content) + } +} diff --git a/internal/tools/codebase/searchtext.go b/internal/tools/codebase/searchtext.go new file mode 100644 index 00000000..dda70073 --- /dev/null +++ b/internal/tools/codebase/searchtext.go @@ -0,0 +1,119 @@ +package codebase + +import ( + "context" + "encoding/json" + "strings" + + "neo-code/internal/repository" + "neo-code/internal/tools" +) + +// SearchTextTool implements the codebase_search_text tool. +type SearchTextTool struct { + root string + svc *repository.Service +} + +// NewSearchText creates a new codebase_search_text tool. +func NewSearchText(svc *repository.Service, root string) *SearchTextTool { + return &SearchTextTool{root: root, svc: svc} +} + +func (t *SearchTextTool) Name() string { + return tools.ToolNameCodebaseSearchText +} + +func (t *SearchTextTool) Description() string { + return "Search for text occurrences across the workspace. Returns file paths, line hints, and match counts. Does NOT return code snippets; use codebase_read to view content." 
+} + +func (t *SearchTextTool) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "Text to search for.", + }, + "scope_dir": map[string]any{ + "type": "string", + "description": "Optional subdirectory to limit the search scope.", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of hits to return (default 20, max 50).", + }, + "workdir": map[string]any{ + "type": "string", + "description": "Optional working directory relative to the workspace root.", + }, + }, + "required": []string{"query"}, + } +} + +func (t *SearchTextTool) MicroCompactPolicy() tools.MicroCompactPolicy { + return tools.MicroCompactPolicyCompact +} + +func (t *SearchTextTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { + var in struct { + Query string `json:"query"` + ScopeDir string `json:"scope_dir,omitempty"` + Limit int `json:"limit,omitempty"` + Workdir string `json:"workdir,omitempty"` + } + if err := json.Unmarshal(call.Arguments, &in); err != nil { + return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err + } + if strings.TrimSpace(in.Query) == "" { + return tools.NewErrorResult(t.Name(), "missing required argument: query", "", nil), nil + } + + root, err := tools.ResolveEffectiveRoot(t.root, in.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } + opts := repository.SearchOptions{ + ScopeDir: in.ScopeDir, + Limit: in.Limit, + } + result, err := t.svc.SearchText(ctx, root, in.Query, opts) + if err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + content := formatTextSearchResult(result) + return tools.ToolResult{ + Name: t.Name(), + Content: content, + Metadata: map[string]any{ + "returned_count": len(result.Hits), + "total_count": result.TotalCount, + 
"truncated": result.Truncated, + }, + }, nil +} + +func formatTextSearchResult(r repository.TextSearchResult) string { + var b strings.Builder + b.WriteString("returned_count: ") + b.WriteString(itoa(len(r.Hits))) + b.WriteString("\ntotal_count: ") + b.WriteString(itoa(r.TotalCount)) + b.WriteString("\ntruncated: ") + b.WriteString(boolToString(r.Truncated)) + if len(r.Hits) > 0 { + b.WriteString("\n") + } + for _, h := range r.Hits { + b.WriteString("\n- path: ") + b.WriteString(h.Path) + b.WriteString("\n line_hint: ") + b.WriteString(itoa(h.LineHint)) + b.WriteString("\n match_count: ") + b.WriteString(itoa(h.MatchCount)) + } + return b.String() +} diff --git a/internal/tools/codebase/searchtext_test.go b/internal/tools/codebase/searchtext_test.go new file mode 100644 index 00000000..06741a6b --- /dev/null +++ b/internal/tools/codebase/searchtext_test.go @@ -0,0 +1,163 @@ +package codebase + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "neo-code/internal/repository" + "neo-code/internal/tools" +) + +func TestSearchTextToolMetadata(t *testing.T) { + t.Parallel() + + tool := NewSearchText(repository.NewService(), "/workspace") + if tool.Name() != "codebase_search_text" { + t.Fatalf("Name() = %q, want %q", tool.Name(), "codebase_search_text") + } + if tool.Description() == "" { + t.Fatalf("Description() should not be empty") + } + schema := tool.Schema() + if schema == nil { + t.Fatalf("Schema() should not be nil") + } + props, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatalf("Schema properties should be a map") + } + if _, hasQuery := props["query"]; !hasQuery { + t.Fatalf("Schema should have query property") + } + if tool.MicroCompactPolicy() != tools.MicroCompactPolicyCompact { + t.Fatalf("MicroCompactPolicy() = %v, want Compact", tool.MicroCompactPolicy()) + } +} + +func TestSearchTextToolInvalidJSON(t *testing.T) { + t.Parallel() + + tool := NewSearchText(repository.NewService(), "/workspace") + 
result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: []byte(`{invalid`), + }) + if err == nil { + t.Fatalf("expected error for invalid JSON, got result: %+v", result) + } + if !result.IsError { + t.Fatalf("expected IsError result") + } +} + +func TestSearchTextToolMissingQuery(t *testing.T) { + t.Parallel() + + tool := NewSearchText(repository.NewService(), "/workspace") + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{}), + }) + if err != nil { + t.Fatalf("expected no error for missing query, got %v", err) + } + if !result.IsError { + t.Fatalf("expected IsError result") + } + if !strings.Contains(result.Content, "missing required argument") { + t.Fatalf("expected missing argument message, got %q", result.Content) + } +} + +func TestSearchTextToolFindsMatches(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + if err := os.WriteFile(filepath.Join(workspace, "test.go"), []byte("func Hello() {\n\treturn \"hello\"\n}\n"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + if err := os.WriteFile(filepath.Join(workspace, "other.go"), []byte("package main\n"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + tool := NewSearchText(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"query": "Hello"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result.Content, "returned_count: 1") { + t.Fatalf("expected 1 hit, got %q", result.Content) + } + if !strings.Contains(result.Content, "test.go") { + t.Fatalf("expected 'test.go' in result, got %q", result.Content) + } +} + +func TestSearchTextToolNoMatches(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + if err := os.WriteFile(filepath.Join(workspace, "test.go"), []byte("func Hello() 
{}\n"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + tool := NewSearchText(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"query": "NonExistentSymbol"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result.Content, "returned_count: 0") { + t.Fatalf("expected 0 hits, got %q", result.Content) + } +} + +func TestSearchTextToolRespectsScopeDir(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + if err := os.WriteFile(filepath.Join(workspace, "root.go"), []byte("found"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + sub := filepath.Join(workspace, "sub") + if err := os.Mkdir(sub, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(sub, "nested.go"), []byte("found"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + tool := NewSearchText(repository.NewService(), workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: mustArgs(t, map[string]any{"query": "found", "scope_dir": "sub"}), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should only find nested.go, not root.go + if strings.Contains(result.Content, "root.go") { + t.Fatalf("expected scope_dir to limit results, got root.go in %q", result.Content) + } + if !strings.Contains(result.Content, "nested.go") { + t.Fatalf("expected nested.go in scoped results, got %q", result.Content) + } +} + +func mustArgs(t *testing.T, v map[string]any) []byte { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return b +} diff --git a/internal/tools/filesystem/copy_file.go b/internal/tools/filesystem/copy_file.go index c03b6111..5ba62aea 100644 --- a/internal/tools/filesystem/copy_file.go +++ b/internal/tools/filesystem/copy_file.go @@ -75,7 
+75,10 @@ func (t *CopyFileTool) Execute(ctx context.Context, input tools.ToolCallInput) ( return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - base := effectiveRoot(t.root, input.Workdir) + base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } src, err := resolvePath(base, args.SourcePath) if err != nil { diff --git a/internal/tools/filesystem/create_dir.go b/internal/tools/filesystem/create_dir.go index 538ebb03..22d0f8de 100644 --- a/internal/tools/filesystem/create_dir.go +++ b/internal/tools/filesystem/create_dir.go @@ -71,7 +71,10 @@ func (t *CreateDirTool) Execute(ctx context.Context, input tools.ToolCallInput) recursive = *args.Recursive } - base := effectiveRoot(t.root, input.Workdir) + base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } _, target, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.Path, resolvePath) if err != nil { diff --git a/internal/tools/filesystem/delete_file.go b/internal/tools/filesystem/delete_file.go index 5415c9ca..32aba86d 100644 --- a/internal/tools/filesystem/delete_file.go +++ b/internal/tools/filesystem/delete_file.go @@ -61,7 +61,10 @@ func (t *DeleteFileTool) Execute(ctx context.Context, input tools.ToolCallInput) return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - base := effectiveRoot(t.root, input.Workdir) + base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } _, target, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.Path, resolvePath) if err != nil { diff --git a/internal/tools/filesystem/edit.go b/internal/tools/filesystem/edit.go 
index 631e1082..2a07b6db 100644 --- a/internal/tools/filesystem/edit.go +++ b/internal/tools/filesystem/edit.go @@ -77,7 +77,10 @@ func (t *EditTool) Execute(ctx context.Context, input tools.ToolCallInput) (tool return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - root := effectiveRoot(t.root, input.Workdir) + root, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } root, target, err := tools.ResolveWorkspaceTarget( input, security.TargetTypePath, @@ -127,5 +130,6 @@ func (t *EditTool) Execute(ctx context.Context, input tools.ToolCallInput) (tool "search_length": len(args.SearchString), "replacement_length": len(args.ReplaceString), }, + Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, }, nil } diff --git a/internal/tools/filesystem/edit_test.go b/internal/tools/filesystem/edit_test.go index d8c499e4..76f04463 100644 --- a/internal/tools/filesystem/edit_test.go +++ b/internal/tools/filesystem/edit_test.go @@ -105,6 +105,9 @@ func TestEditToolExecute(t *testing.T) { if result.Content != "ok" { t.Fatalf("expected result content ok, got %q", result.Content) } + if !result.Facts.WorkspaceWrite { + t.Fatalf("expected WorkspaceWrite=true for successful edit, got false") + } data, err := os.ReadFile(filepath.Join(workspace, tt.path)) if err != nil { diff --git a/internal/tools/filesystem/glob.go b/internal/tools/filesystem/glob.go index 7b4067f7..3579ff8d 100644 --- a/internal/tools/filesystem/glob.go +++ b/internal/tools/filesystem/glob.go @@ -82,7 +82,10 @@ func (t *GlobTool) Execute(ctx context.Context, input tools.ToolCallInput) (tool return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - root := effectiveRoot(t.root, input.Workdir) + root, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid 
workdir", err.Error(), nil), err + } searchRoot, err := resolveSearchDir(root, args.Dir) if err != nil { return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err diff --git a/internal/tools/filesystem/grep.go b/internal/tools/filesystem/grep.go index 5ef5c8fc..7bf1deca 100644 --- a/internal/tools/filesystem/grep.go +++ b/internal/tools/filesystem/grep.go @@ -80,7 +80,10 @@ func (t *GrepTool) Execute(ctx context.Context, input tools.ToolCallInput) (tool return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - root := effectiveRoot(t.root, input.Workdir) + root, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } searchRoot, err := resolveSearchDir(root, args.Dir) if err != nil { return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err diff --git a/internal/tools/filesystem/helpers.go b/internal/tools/filesystem/helpers.go index 4606cb27..5296914f 100644 --- a/internal/tools/filesystem/helpers.go +++ b/internal/tools/filesystem/helpers.go @@ -21,14 +21,6 @@ const ( removeDirToolName = tools.ToolNameFilesystemRemoveDir ) -func effectiveRoot(defaultRoot string, workdir string) string { - base := strings.TrimSpace(workdir) - if base == "" { - base = defaultRoot - } - return base -} - func toRelativePath(root string, target string) string { base, err := filepath.Abs(root) if err != nil { diff --git a/internal/tools/filesystem/metadata_test.go b/internal/tools/filesystem/metadata_test.go index 58edb373..04cb0e1c 100644 --- a/internal/tools/filesystem/metadata_test.go +++ b/internal/tools/filesystem/metadata_test.go @@ -31,12 +31,6 @@ func TestToolMetadataAndHelpers(t *testing.T) { } } - if effectiveRoot("", root) != root { - t.Fatalf("expected workdir fallback") - } - if got := effectiveRoot(root, ""); got != root { - t.Fatalf("expected default 
root, got %q", got) - } if rel := toRelativePath(root, root); rel != "." { t.Fatalf("expected relative root path '.', got %q", rel) } diff --git a/internal/tools/filesystem/move_file.go b/internal/tools/filesystem/move_file.go index c627b85f..6ed7121a 100644 --- a/internal/tools/filesystem/move_file.go +++ b/internal/tools/filesystem/move_file.go @@ -76,7 +76,10 @@ func (t *MoveFileTool) Execute(ctx context.Context, input tools.ToolCallInput) ( return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - base := effectiveRoot(t.root, input.Workdir) + base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } src, err := resolvePath(base, args.SourcePath) if err != nil { diff --git a/internal/tools/filesystem/move_file_test.go b/internal/tools/filesystem/move_file_test.go index bd95e86d..b76e8cc8 100644 --- a/internal/tools/filesystem/move_file_test.go +++ b/internal/tools/filesystem/move_file_test.go @@ -47,7 +47,7 @@ func TestMoveFileTool_RenamesWithinWorkspace(t *testing.T) { } else if string(data) != "hello" { t.Fatalf("dst content = %q want hello", string(data)) } - if got := result.Metadata["destination_path"]; got != dst { + if got, ok := result.Metadata["destination_path"].(string); !ok || !strings.EqualFold(got, dst) { t.Fatalf("destination_path metadata = %v want %v", got, dst) } paths, ok := result.Metadata["paths"].([]string) diff --git a/internal/tools/filesystem/read_file.go b/internal/tools/filesystem/read_file.go index ca01d2ea..c789fdd2 100644 --- a/internal/tools/filesystem/read_file.go +++ b/internal/tools/filesystem/read_file.go @@ -78,7 +78,10 @@ func (t *ReadFileTool) Execute(ctx context.Context, input tools.ToolCallInput) ( return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - base := effectiveRoot(t.root, input.Workdir) + base, err := 
tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } base, target, err := tools.ResolveWorkspaceTarget( input, diff --git a/internal/tools/filesystem/remove_dir.go b/internal/tools/filesystem/remove_dir.go index cc611b6a..ae0be3bb 100644 --- a/internal/tools/filesystem/remove_dir.go +++ b/internal/tools/filesystem/remove_dir.go @@ -66,7 +66,10 @@ func (t *RemoveDirTool) Execute(ctx context.Context, input tools.ToolCallInput) return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - base := effectiveRoot(t.root, input.Workdir) + base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } _, target, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.Path, resolvePath) if err != nil { diff --git a/internal/tools/filesystem/write_file.go b/internal/tools/filesystem/write_file.go index 2df6b43d..1185fb25 100644 --- a/internal/tools/filesystem/write_file.go +++ b/internal/tools/filesystem/write_file.go @@ -79,7 +79,10 @@ func (t *WriteFileTool) Execute(ctx context.Context, input tools.ToolCallInput) return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - base := effectiveRoot(t.root, input.Workdir) + base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) + if err != nil { + return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err + } _, target, err := tools.ResolveWorkspaceTarget( input, @@ -138,6 +141,7 @@ func (t *WriteFileTool) Execute(ctx context.Context, input tools.ToolCallInput) "noop_write": false, "content_unchanged": false, }, + Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, } if token, ok := compactWriteVerificationToken(args.Content); ok { result.Metadata["written_content"] = token diff --git 
a/internal/tools/filesystem/write_file_test.go b/internal/tools/filesystem/write_file_test.go index 597a2755..7481882e 100644 --- a/internal/tools/filesystem/write_file_test.go +++ b/internal/tools/filesystem/write_file_test.go @@ -106,6 +106,9 @@ func TestWriteFileToolMetadataAndExecute(t *testing.T) { if result.Content != "ok" { t.Fatalf("expected ok result, got %q", result.Content) } + if !result.Facts.WorkspaceWrite { + t.Fatalf("expected WorkspaceWrite=true for fresh write, got false") + } data, err := os.ReadFile(tt.expectPath) if err != nil { @@ -178,6 +181,9 @@ func TestWriteFileToolNoopWriteMetadata(t *testing.T) { if !ok || !unchanged { t.Fatalf("content_unchanged metadata = %#v, want true", result.Metadata["content_unchanged"]) } + if result.Facts.WorkspaceWrite { + t.Fatalf("expected WorkspaceWrite=false for noop write, got true") + } } func TestWriteFileToolVerifyAfterWriteFacts(t *testing.T) { @@ -206,6 +212,9 @@ func TestWriteFileToolVerifyAfterWriteFacts(t *testing.T) { if !result.Facts.VerificationPerformed || !result.Facts.VerificationPassed { t.Fatalf("verification facts = %+v, want performed=true passed=true", result.Facts) } + if !result.Facts.WorkspaceWrite { + t.Fatalf("expected WorkspaceWrite=true for fresh write, got false") + } if token, _ := result.Metadata["written_content"].(string); token != "verified-content" { t.Fatalf("written_content = %#v, want verified-content", result.Metadata["written_content"]) } @@ -242,6 +251,9 @@ func TestWriteFileToolVerifyAfterWriteFacts(t *testing.T) { if !result.Facts.VerificationPerformed || !result.Facts.VerificationPassed { t.Fatalf("verification facts = %+v, want performed=true passed=true", result.Facts) } + if result.Facts.WorkspaceWrite { + t.Fatalf("expected WorkspaceWrite=false for noop write, got true") + } if result.Facts.VerificationScope != "artifact:same-verified.txt" { t.Fatalf("verification scope = %q", result.Facts.VerificationScope) } diff --git a/internal/tools/names.go 
b/internal/tools/names.go index 625a1dc6..47c09bec 100644 --- a/internal/tools/names.go +++ b/internal/tools/names.go @@ -21,4 +21,8 @@ const ( ToolNameMemoList = "memo_list" ToolNameMemoRemove = "memo_remove" ToolNameDiagnose = "diagnose" + + ToolNameCodebaseRead = "codebase_read" + ToolNameCodebaseSearchText = "codebase_search_text" + ToolNameCodebaseSearchSymbol = "codebase_search_symbol" ) diff --git a/internal/tools/workspace_plan.go b/internal/tools/workspace_plan.go index e9ec2aad..bc34024f 100644 --- a/internal/tools/workspace_plan.go +++ b/internal/tools/workspace_plan.go @@ -9,6 +9,17 @@ import ( "neo-code/internal/security" ) +// ResolveEffectiveRoot resolves a user-supplied workdir against the configured +// workspace root and ensures the result stays within the workspace boundary. +func ResolveEffectiveRoot(defaultRoot string, workdir string) (string, error) { + base := strings.TrimSpace(workdir) + if base == "" { + return defaultRoot, nil + } + _, resolved, err := security.ResolveWorkspacePath(defaultRoot, base) + return resolved, err +} + type workspaceResolver func(root string, requested string) (string, error) // ResolveWorkspaceTarget resolves the effective execution target for one tool diff --git a/scripts/generate_gateway_rpc_examples/main.go b/scripts/generate_gateway_rpc_examples/main.go index f59e3e57..39585c8b 100644 --- a/scripts/generate_gateway_rpc_examples/main.go +++ b/scripts/generate_gateway_rpc_examples/main.go @@ -236,7 +236,7 @@ func buildMethodExamples() ([]methodExample, error) { SavedRatio: 0.63, TriggerMode: "manual", TranscriptID: "compact-demo-1", - TranscriptPath: ".neocode/transcripts/compact-demo-1.md", + TranscriptPath: ".neocode/transcripts/compact-demo-subagent.md", }, }) compactFailure := buildFailureResponse(