diff --git a/cmd/trace/cli/agent/codex/lifecycle.go b/cmd/trace/cli/agent/codex/lifecycle.go index d8d6d8d..2895f8a 100644 --- a/cmd/trace/cli/agent/codex/lifecycle.go +++ b/cmd/trace/cli/agent/codex/lifecycle.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "strings" "time" "github.com/GrayCodeAI/trace/cmd/trace/cli/agent" @@ -35,6 +36,7 @@ const ( HookNameUserPromptSubmit = "user-prompt-submit" HookNameStop = "stop" HookNamePreToolUse = "pre-tool-use" + HookNamePostToolUse = "post-tool-use" ) // HookNames returns the hook verbs Codex supports. @@ -44,6 +46,7 @@ func (c *CodexAgent) HookNames() []string { HookNameUserPromptSubmit, HookNameStop, HookNamePreToolUse, + HookNamePostToolUse, } } @@ -60,6 +63,8 @@ func (c *CodexAgent) ParseHookEvent(_ context.Context, hookName string, stdin io case HookNamePreToolUse: // PreToolUse has no lifecycle significance — pass through return nil, nil //nolint:nilnil // nil event = no lifecycle action + case HookNamePostToolUse: + return c.parsePostToolUse(stdin) default: return nil, nil //nolint:nilnil // Unknown hooks have no lifecycle action } @@ -107,3 +112,64 @@ func (c *CodexAgent) parseTurnEnd(stdin io.Reader) (*agent.Event, error) { Timestamp: time.Now(), }, nil } + +func (c *CodexAgent) parsePostToolUse(stdin io.Reader) (*agent.Event, error) { + raw, err := agent.ReadAndParseHookInput[postToolUseRaw](stdin) + if err != nil { + return nil, err + } + + // Only apply_patch carries file changes worth tracking. 
+ if raw.ToolName != "apply_patch" { + return nil, nil //nolint:nilnil // non-mutating tools have no lifecycle action + } + + var input applyPatchInput + if err := json.Unmarshal(raw.ToolInput, &input); err != nil { + return nil, fmt.Errorf("failed to parse apply_patch input: %w", err) + } + + added, updated, deleted := parseApplyPatchFiles(input.Patch) + if len(added) == 0 && len(updated) == 0 && len(deleted) == 0 { + return nil, nil //nolint:nilnil // empty patch has no lifecycle action + } + + return &agent.Event{ + Type: agent.ToolUse, + SessionID: raw.SessionID, + SessionRef: derefString(raw.TranscriptPath), + ToolName: raw.ToolName, + ToolUseID: raw.ToolUseID, + ModifiedFiles: updated, + NewFiles: added, + DeletedFiles: deleted, + Timestamp: time.Now(), + }, nil +} + +// parseApplyPatchFiles extracts file paths from a Codex apply_patch envelope. +// The patch format uses markers: +// +// *** Add File: path +// *** Update File: path +// *** Delete File: path +func parseApplyPatchFiles(patch string) (added, updated, deleted []string) { + for line := range strings.SplitSeq(patch, "\n") { + line = strings.TrimSpace(line) + switch { + case strings.HasPrefix(line, "*** Add File:"): + if p := strings.TrimSpace(strings.TrimPrefix(line, "*** Add File:")); p != "" { + added = append(added, p) + } + case strings.HasPrefix(line, "*** Update File:"): + if p := strings.TrimSpace(strings.TrimPrefix(line, "*** Update File:")); p != "" { + updated = append(updated, p) + } + case strings.HasPrefix(line, "*** Delete File:"): + if p := strings.TrimSpace(strings.TrimPrefix(line, "*** Delete File:")); p != "" { + deleted = append(deleted, p) + } + } + } + return added, updated, deleted +} diff --git a/cmd/trace/cli/agent/codex/lifecycle_test.go b/cmd/trace/cli/agent/codex/lifecycle_test.go index ee27637..5346f16 100644 --- a/cmd/trace/cli/agent/codex/lifecycle_test.go +++ b/cmd/trace/cli/agent/codex/lifecycle_test.go @@ -130,3 +130,119 @@ func 
TestParseHookEvent_MalformedJSON_ReturnsError(t *testing.T) { _, err := ag.ParseHookEvent(context.Background(), HookNameSessionStart, strings.NewReader("{invalid json")) require.Error(t, err) } + +func TestParseHookEvent_PostToolUse_ApplyPatch(t *testing.T) { + t.Parallel() + ag := &CodexAgent{} + input := `{ + "session_id": "test-uuid", + "turn_id": "turn-1", + "transcript_path": null, + "cwd": "/tmp/repo", + "hook_event_name": "PostToolUse", + "model": "gpt-5", + "permission_mode": "default", + "tool_name": "apply_patch", + "tool_use_id": "call-patch", + "tool_input": {"patch": "*** Add File: a.go\n+hello\n*** Update File: b.go\n@@\n-old\n+new\n*** Delete File: c.go\n*** End Patch\n"}, + "tool_response": "Patch applied successfully." + }` + + event, err := ag.ParseHookEvent(context.Background(), HookNamePostToolUse, strings.NewReader(input)) + require.NoError(t, err) + require.NotNil(t, event) + require.Equal(t, agent.ToolUse, event.Type) + require.Equal(t, "test-uuid", event.SessionID) + require.Equal(t, "apply_patch", event.ToolName) + require.Equal(t, []string{"a.go"}, event.NewFiles) + require.Equal(t, []string{"b.go"}, event.ModifiedFiles) + require.Equal(t, []string{"c.go"}, event.DeletedFiles) +} + +func TestParseHookEvent_PostToolUse_NonApplyPatch_ReturnsNil(t *testing.T) { + t.Parallel() + ag := &CodexAgent{} + input := `{ + "session_id": "test-uuid", + "turn_id": "turn-1", + "transcript_path": null, + "cwd": "/tmp/repo", + "hook_event_name": "PostToolUse", + "model": "gpt-5", + "permission_mode": "default", + "tool_name": "shell", + "tool_use_id": "call-shell", + "tool_input": {"command": ["echo", "hi"]}, + "tool_response": "hi\n" + }` + + event, err := ag.ParseHookEvent(context.Background(), HookNamePostToolUse, strings.NewReader(input)) + require.NoError(t, err) + require.Nil(t, event) +} + +func TestParseApplyPatchFiles(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + patch string + wantAdded []string + wantUpdated []string + 
wantDeleted []string + }{ + { + name: "all three operations", + patch: "*** Begin Patch\n" + + "*** Add File: docs/added.md\n" + + "+# added\n" + + "*** Update File: src/changed.go\n" + + "@@\n" + + "-old\n" + + "+new\n" + + "*** Delete File: tmp/gone.txt\n" + + "*** End Patch\n", + wantAdded: []string{"docs/added.md"}, + wantUpdated: []string{"src/changed.go"}, + wantDeleted: []string{"tmp/gone.txt"}, + }, + { + name: "empty patch", + patch: "", + wantAdded: nil, + wantUpdated: nil, + wantDeleted: nil, + }, + { + name: "only adds", + patch: "*** Add File: a.go\n" + + "+line1\n" + + "*** Add File: b.go\n" + + "+line2\n", + wantAdded: []string{"a.go", "b.go"}, + wantUpdated: nil, + wantDeleted: nil, + }, + { + name: "no markers", + patch: "*** Begin Patch\n" + + "@@\n" + + "-old\n" + + "+new\n" + + "*** End Patch\n", + wantAdded: nil, + wantUpdated: nil, + wantDeleted: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + added, updated, deleted := parseApplyPatchFiles(tt.patch) + require.Equal(t, tt.wantAdded, added) + require.Equal(t, tt.wantUpdated, updated) + require.Equal(t, tt.wantDeleted, deleted) + }) + } +} diff --git a/cmd/trace/cli/agent/codex/types.go b/cmd/trace/cli/agent/codex/types.go index 6356f64..da34d0b 100644 --- a/cmd/trace/cli/agent/codex/types.go +++ b/cmd/trace/cli/agent/codex/types.go @@ -1,5 +1,7 @@ package codex +import "encoding/json" + // HooksFile represents the .codex/hooks.json structure. type HooksFile struct { Hooks HookEvents `json:"hooks"` @@ -63,6 +65,26 @@ type stopRaw struct { LastAssistantMessage *string `json:"last_assistant_message"` // nullable } +// postToolUseRaw is the JSON structure from PostToolUse hooks. 
+type postToolUseRaw struct { + SessionID string `json:"session_id"` + TurnID string `json:"turn_id"` + TranscriptPath *string `json:"transcript_path"` // nullable + CWD string `json:"cwd"` + HookEventName string `json:"hook_event_name"` + Model string `json:"model"` + PermissionMode string `json:"permission_mode"` + ToolName string `json:"tool_name"` + ToolUseID string `json:"tool_use_id"` + ToolInput json.RawMessage `json:"tool_input"` + ToolResponse json.RawMessage `json:"tool_response"` +} + +// applyPatchInput is the structure of tool_input for apply_patch. +type applyPatchInput struct { + Patch string `json:"patch"` +} + // derefString safely dereferences a nullable string pointer. func derefString(s *string) string { if s == nil { diff --git a/cmd/trace/cli/agent/event.go b/cmd/trace/cli/agent/event.go index b3759e5..65c313e 100644 --- a/cmd/trace/cli/agent/event.go +++ b/cmd/trace/cli/agent/event.go @@ -40,6 +40,11 @@ const ( // (e.g., Gemini CLI's BeforeModel). The framework stores the model as a hint // for subsequent TurnStart/TurnEnd events in the same session. ModelUpdate + + // ToolUse indicates a tool was used mid-turn (e.g., apply_patch, write_file). + // The framework merges the tool's file list into session.FilesTouched so that + // mid-turn commits have accurate carry-forward data. + ToolUse ) // String returns a human-readable name for the event type. @@ -61,6 +66,8 @@ func (e EventType) String() string { return "SubagentEnd" case ModelUpdate: return "ModelUpdate" + case ToolUse: + return "ToolUse" default: return "Unknown" } @@ -96,6 +103,10 @@ type Event struct { // ToolUseID identifies the tool invocation (for SubagentStart/SubagentEnd events). ToolUseID string + // ToolName identifies the tool that was used (for ToolUse events). + // Agents set this to their native tool identifier (e.g., "apply_patch" for Codex). + ToolName string + // SubagentID identifies the subagent instance (for SubagentEnd events). 
SubagentID string @@ -109,11 +120,20 @@ type Event struct { SubagentType string TaskDescription string - // ModifiedFiles is a list of file paths modified by a subagent. + // ModifiedFiles is a list of file paths modified by a subagent or tool. // Populated on SubagentEnd events when the agent provides this data - // directly via hook payload (e.g., Cursor's subagentStop). + // directly via hook payload (e.g., Cursor's subagentStop), and on + // ToolUse events for updated files (e.g., Codex apply_patch). ModifiedFiles []string + // NewFiles is a list of file paths newly created by a tool. + // Populated on ToolUse events (e.g., Codex apply_patch "Add File"). + NewFiles []string + + // DeletedFiles is a list of file paths deleted by a tool. + // Populated on ToolUse events (e.g., Codex apply_patch "Delete File"). + DeletedFiles []string + // ResponseMessage is an optional message to display to the user via the agent. ResponseMessage string diff --git a/cmd/trace/cli/api/base_url_test.go b/cmd/trace/cli/api/base_url_test.go index 923e758..474de20 100644 --- a/cmd/trace/cli/api/base_url_test.go +++ b/cmd/trace/cli/api/base_url_test.go @@ -35,7 +35,7 @@ func TestResolveURLFromBase_RejectsNonHTTPScheme(t *testing.T) { func TestRequireSecureURL_AllowsHTTPS(t *testing.T) { t.Parallel() - if err := RequireSecureURL("https://trace.io"); err != nil { + if err := RequireSecureURL("https://entire.io"); err != nil { t.Fatalf("RequireSecureURL(https) = %v, want nil", err) } } diff --git a/cmd/trace/cli/api/client.go b/cmd/trace/cli/api/client.go index 7747e5f..5d087b8 100644 --- a/cmd/trace/cli/api/client.go +++ b/cmd/trace/cli/api/client.go @@ -12,7 +12,7 @@ import ( ) const ( - maxResponseBytes = 1 << 20 + maxResponseBytes = 16 << 20 // 16 MiB – increased from 1 MiB to support large trail/checkpoint payloads (ported from upstream) userAgent = "trace-cli" ) diff --git a/cmd/trace/cli/api/client_test.go b/cmd/trace/cli/api/client_test.go index 59d7825..fc021a0 100644 --- 
a/cmd/trace/cli/api/client_test.go +++ b/cmd/trace/cli/api/client_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -46,8 +47,8 @@ func TestBearerTransport_InjectsAuthHeader(t *testing.T) { if gotAuth != "Bearer test-token-123" { t.Errorf("Authorization = %q, want %q", gotAuth, "Bearer test-token-123") } - if gotUA != "trace-cli" { - t.Errorf("User-Agent = %q, want %q", gotUA, "trace-cli") + if gotUA != "entire-cli" { + t.Errorf("User-Agent = %q, want %q", gotUA, "entire-cli") } if gotAccept != "application/json" { t.Errorf("Accept = %q, want %q", gotAccept, "application/json") @@ -246,6 +247,31 @@ func TestCheckResponse_ErrorWithJSON(t *testing.T) { } } +func TestCheckResponse_ErrorWithObjectEnvelope(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{"error":{"code":"not_found","message":"session not found","field":null,"retryable":false}}`)) //nolint:errcheck // test handler + })) + defer server.Close() + + resp, err := http.Get(server.URL) //nolint:noctx // test helper + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + err = CheckResponse(resp) + if err == nil { + t.Fatal("CheckResponse(404) = nil, want error") + } + if got := err.Error(); got != "API error: session not found (status 404)" { + t.Errorf("error = %q", got) + } +} + func TestCheckResponse_ErrorWithPlainText(t *testing.T) { t.Parallel() @@ -300,3 +326,56 @@ func TestDecodeJSONResponse(t *testing.T) { t.Errorf("Status = %q, want %q", result.Status, "ok") } } + +// TestDecodeJSONResponse_LargeBodyOverOldCap exercises a JSON body whose size +// exceeds the previous 1 MiB read cap. 
`entire activity` requests up to a +// month of commits, which routinely produces 1.5+ MiB responses; under the +// old cap, io.LimitReader truncated the body mid-JSON and json.Unmarshal +// surfaced "unexpected end of JSON input", masking the real cause. +func TestDecodeJSONResponse_LargeBodyOverOldCap(t *testing.T) { + t.Parallel() + + const itemCount = 4000 // ~2 MiB at ~500 bytes per item + type item struct { + ID string `json:"id"` + Message string `json:"message"` + } + payload := struct { + Items []item `json:"items"` + }{Items: make([]item, itemCount)} + for i := range payload.Items { + payload.Items[i] = item{ + ID: "0123456789abcdef0123456789abcdef0123456789abcdef", + Message: strings.Repeat("x", 400), + } + } + encoded, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + if len(encoded) <= 1<<20 { + t.Fatalf("test payload %d bytes is not over the old 1 MiB cap", len(encoded)) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(encoded) //nolint:errcheck // test handler + })) + defer server.Close() + + resp, err := http.Get(server.URL) //nolint:noctx // test helper + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + var got struct { + Items []item `json:"items"` + } + if err := DecodeJSON(resp, &got); err != nil { + t.Fatalf("DecodeJSON(%d-byte body) = %v, want nil", len(encoded), err) + } + if len(got.Items) != itemCount { + t.Errorf("decoded %d items, want %d", len(got.Items), itemCount) + } +} diff --git a/cmd/trace/cli/api/repositories_test.go b/cmd/trace/cli/api/repositories_test.go index b24d4b3..70453ee 100644 --- a/cmd/trace/cli/api/repositories_test.go +++ b/cmd/trace/cli/api/repositories_test.go @@ -19,8 +19,8 @@ func TestClient_ListRepositories_SendsSortAndDecodesResponse(t *testing.T) { gotAuth = r.Header.Get("Authorization") w.Header().Set("Content-Type", "application/json") 
w.Write([]byte(`{"repositories":[` + //nolint:errcheck // test handler - `{"full_name":"GrayCodeAI/cli","checkpoint_count":12},` + - `{"full_name":"GrayCodeAI/trace.io","checkpoint_count":3}` + + `{"full_name":"entireio/cli","checkpoint_count":12},` + + `{"full_name":"entireio/entire.io","checkpoint_count":3}` + `]}`)) })) defer server.Close() @@ -46,10 +46,10 @@ func TestClient_ListRepositories_SendsSortAndDecodesResponse(t *testing.T) { if len(repos) != 2 { t.Fatalf("len(repos) = %d, want 2", len(repos)) } - if repos[0].FullName != "GrayCodeAI/cli" || repos[0].CheckpointCount != 12 { + if repos[0].FullName != "entireio/cli" || repos[0].CheckpointCount != 12 { t.Errorf("repos[0] = %+v", repos[0]) } - if repos[1].FullName != "GrayCodeAI/trace.io" || repos[1].CheckpointCount != 3 { + if repos[1].FullName != "entireio/entire.io" || repos[1].CheckpointCount != 3 { t.Errorf("repos[1] = %+v", repos[1]) } } diff --git a/cmd/trace/cli/api/trail_types.go b/cmd/trace/cli/api/trail_types.go index 19a4926..9b2e830 100644 --- a/cmd/trace/cli/api/trail_types.go +++ b/cmd/trace/cli/api/trail_types.go @@ -16,8 +16,8 @@ type TrailListResponse struct { // TrailResource represents a single trail from the API. 
type TrailResource struct { + ID string `json:"id,omitempty"` Number int `json:"number,omitempty"` - TrailID string `json:"trail_id"` Branch string `json:"branch"` Base string `json:"base"` Title string `json:"title"` @@ -42,7 +42,7 @@ type TrailResource struct { func (r *TrailResource) ToMetadata() *trail.Metadata { m := &trail.Metadata{ Number: r.Number, - TrailID: trail.ID(r.TrailID), + TrailID: trail.ID(r.ID), Branch: r.Branch, Base: r.Base, Title: r.Title, diff --git a/cmd/trace/cli/api/trail_types_test.go b/cmd/trace/cli/api/trail_types_test.go new file mode 100644 index 0000000..ea3ab7d --- /dev/null +++ b/cmd/trace/cli/api/trail_types_test.go @@ -0,0 +1,12 @@ +package api + +import "testing" + +func TestTrailResourceToMetadataUsesID(t *testing.T) { + t.Parallel() + + metadata := (&TrailResource{ID: "trail-db-id", Branch: "feature/x"}).ToMetadata() + if got := metadata.TrailID.String(); got != "trail-db-id" { + t.Fatalf("metadata TrailID = %q, want stable API id", got) + } +} diff --git a/cmd/trace/cli/checkpoint/shadow_ref.go b/cmd/trace/cli/checkpoint/shadow_ref.go new file mode 100644 index 0000000..ba127ec --- /dev/null +++ b/cmd/trace/cli/checkpoint/shadow_ref.go @@ -0,0 +1,164 @@ +package checkpoint + +import ( + "context" + "errors" + "fmt" + "math/rand/v2" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/GrayCodeAI/trace/cmd/trace/cli/internal/flock" + + "github.com/go-git/go-git/v6/plumbing" +) + +// ErrShadowRefBusy is returned by casUpdateShadowBranchRef when the ref has +// moved since the caller read it. Callers retry with a fresh parent. +var ErrShadowRefBusy = errors.New("shadow branch ref moved (CAS mismatch)") + +// shadowRefMaxRetries bounds the WriteTemporary retry loop. With the +// per-shadow-branch flock held, our own writers never collide; this budget +// is purely a safety net against an external `git update-ref` writer that +// repeatedly beats us to the ref. 
+const shadowRefMaxRetries = 16 + +// shadowRefMaxJitter is the upper bound for randomized backoff between CAS +// retries. Random jitter avoids thundering-herd retry patterns when many +// sessions hit the same shadow branch simultaneously. +const shadowRefMaxJitter = 8 * time.Millisecond + +// repoDirs returns the worktree root and git common dir for the store's +// repository. Callers use the worktree root as cmd.Dir for git invocations +// and the common dir to locate filesystem paths (lock files, loose objects) +// — both without depending on the process cwd. +func (s *GitStore) repoDirs(ctx context.Context) (worktreeRoot, commonDir string, err error) { + wt, err := s.repo.Worktree() + if err != nil { + return "", "", fmt.Errorf("open worktree: %w", err) + } + worktreeRoot = wt.Filesystem.Root() + if worktreeRoot == "" { + return "", "", errors.New("repository worktree filesystem has no root path") + } + commonDir, err = resolveGitCommonDir(ctx, s.repo) + if err != nil { + return "", "", err + } + return worktreeRoot, commonDir, nil +} + +// casUpdateShadowBranchRef atomically updates a shadow branch ref via +// `git update-ref `. Pass plumbing.ZeroHash as expectedHash +// to require the ref to NOT exist (first-checkpoint case). +// +// repoRoot is used as cmd.Dir so the update targets the same repository as +// the rest of WriteTemporary (i.e. s.repo) regardless of the process cwd. +// +// Returns ErrShadowRefBusy when git reports the ref moved since expectedHash +// was observed; callers retry with a fresh parent. Any other failure is +// returned wrapped. +// +// Why shell out: git's ref-locking is the canonical cross-process atomic +// CAS — go-git's CheckAndSetReference doesn't interoperate with native git's +// .lock files, and shadow branches can be touched concurrently by separate +// `trace` hook processes. 
+func casUpdateShadowBranchRef(ctx context.Context, repoRoot, branchName string, newHash, expectedHash plumbing.Hash) error { + refName := "refs/heads/" + branchName + + // All-zeros OID with the repo's object-format width means "must not + // exist". SHA-1 repos want 40 zeros, SHA-256 repos want 64; mirror + // newHash's hex width so we pick the right one without an extra git call. + newValue := newHash.String() + oldValue := strings.Repeat("0", newHash.HexSize()) + if expectedHash != plumbing.ZeroHash { + oldValue = expectedHash.String() + } + + cmd := exec.CommandContext(ctx, "git", "update-ref", refName, newValue, oldValue) + cmd.Dir = repoRoot + // Force English diagnostics so the CAS-conflict pattern match below + // isn't defeated by a translated stderr message in a non-C locale. + cmd.Env = append(os.Environ(), "LC_ALL=C", "LANG=C") + output, err := cmd.CombinedOutput() + if err == nil { + return nil + } + + out := string(output) + // Git's CAS-failure messages: "cannot lock ref ..." (covers both + // "is at X but expected Y" and "reference already exists" for the + // zero-OID case). Other failures propagate. + if strings.Contains(out, "cannot lock ref") || strings.Contains(out, "but expected") { + return ErrShadowRefBusy + } + return fmt.Errorf("git update-ref %s: %s: %w", refName, strings.TrimSpace(out), err) +} + +// shadowRefBackoff sleeps for a small random jitter before the next CAS +// retry. After several retries the upper bound doubles to slow the +// thundering herd further. Respects context cancellation. +func shadowRefBackoff(ctx context.Context, attempt int) error { + base := shadowRefMaxJitter + if attempt > 4 { + base *= 2 + } + // Add a 1ms floor so the chosen sleep is always non-trivial, even when + // rand.Int64N happens to return 0. 
+ d := time.Duration(rand.Int64N(int64(base))) + time.Millisecond //nolint:gosec // jitter, not security-sensitive + select { + case <-time.After(d): + return nil + case <-ctx.Done(): + return ctx.Err() //nolint:wrapcheck // canonical context cancellation + } +} + +// shadowBranchLockPath returns the per-shadow-branch flock file path. Lock +// files live in /trace-shadow-locks/ so they don't pollute +// the session-state directory. Branch names are slash-escaped because the +// shadow-branch convention "trace/" would otherwise nest directories. +func shadowBranchLockPath(commonDir, branchName string) (string, error) { + lockDir := filepath.Join(commonDir, "trace-shadow-locks") + if err := os.MkdirAll(lockDir, 0o750); err != nil { + return "", fmt.Errorf("create shadow lock directory: %w", err) + } + safe := strings.ReplaceAll(branchName, "/", "_") + return filepath.Join(lockDir, safe+".lock"), nil +} + +// withShadowBranchFlock acquires the per-shadow-branch flock, runs fn, and +// releases the flock. Serializes all WriteTemporary callers that target the +// same shadow branch — across goroutines AND across processes — so the CAS +// in casUpdateShadowBranchRef only sees external writers as contention. +// +// commonDir is the git common directory (from s.repoDirs); it locates the +// lock file independently of the process cwd. +func withShadowBranchFlock(commonDir, branchName string, fn func() error) error { + path, err := shadowBranchLockPath(commonDir, branchName) + if err != nil { + return err + } + release, err := flock.Acquire(path) + if err != nil { + return fmt.Errorf("acquire shadow flock %s: %w", branchName, err) + } + defer release() + return fn() +} + +// tryDeleteLooseObject best-effort removes a loose object file. Used to +// clean up dangling commits created during a CAS-losing attempt. Failures +// (e.g. 
object already packed by a concurrent gc, or never written as a +// loose object) are ignored — the object will be picked up by the next gc +// pass either way. +func tryDeleteLooseObject(commonDir string, hash plumbing.Hash) { + h := hash.String() + if len(h) < 3 { + return + } + _ = os.Remove(filepath.Join(commonDir, "objects", h[:2], h[2:])) +} diff --git a/cmd/trace/cli/checkpoint/temporary.go b/cmd/trace/cli/checkpoint/temporary.go index c732b08..fed0fbd 100644 --- a/cmd/trace/cli/checkpoint/temporary.go +++ b/cmd/trace/cli/checkpoint/temporary.go @@ -68,20 +68,6 @@ func (s *GitStore) WriteTemporary(ctx context.Context, opts WriteTemporaryOption // Get shadow branch name shadowBranchName := ShadowBranchNameForCommit(opts.BaseCommit, opts.WorktreeID) - // Get or create shadow branch - parentHash, baseTreeHash, err := s.getOrCreateShadowBranch(shadowBranchName) - if err != nil { - return WriteTemporaryResult{}, fmt.Errorf("failed to get shadow branch: %w", err) - } - - // Get the last checkpoint's tree hash for deduplication - var lastTreeHash plumbing.Hash - if parentHash != plumbing.ZeroHash { - if lastCommit, err := s.repo.CommitObject(parentHash); err == nil { - lastTreeHash = lastCommit.TreeHash - } - } - // Collect all files to include var allFiles []string var allDeletedFiles []string @@ -110,39 +96,87 @@ func (s *GitStore) WriteTemporary(ctx context.Context, opts WriteTemporaryOption allDeletedFiles = opts.DeletedFiles } - // Build tree with changes - treeHash, err := s.buildTreeWithChanges(ctx, baseTreeHash, allFiles, allDeletedFiles, opts.MetadataDir, opts.MetadataDirAbs) + // Create checkpoint commit message (constant across retries) + commitMsg := trailers.FormatShadowCommit(opts.CommitMessage, opts.MetadataDir, opts.SessionID) + + repoRoot, commonDir, err := s.repoDirs(ctx) if err != nil { - return WriteTemporaryResult{}, fmt.Errorf("failed to build tree: %w", err) - } + return WriteTemporaryResult{}, fmt.Errorf("failed to resolve repo dirs: %w", 
err) + } + + var result WriteTemporaryResult + // withShadowBranchFlock serializes all writers targeting this shadow + // branch — across goroutines and across processes — so the inner CAS + // only sees contention from external `git update-ref` callers (rare). + err = withShadowBranchFlock(commonDir, shadowBranchName, func() error { + // Tiny CAS retry budget: with the flock held, races against our own + // code are impossible. Retries cover the pathological case of an + // external writer (a user invoking `git update-ref` manually, etc.). + for attempt := range shadowRefMaxRetries { + parentHash, baseTreeHash, gErr := s.getOrCreateShadowBranch(shadowBranchName) + if gErr != nil { + return fmt.Errorf("failed to get shadow branch: %w", gErr) + } - // Deduplication: skip if tree hash matches the last checkpoint - if lastTreeHash != plumbing.ZeroHash && treeHash == lastTreeHash { - return WriteTemporaryResult{ - CommitHash: parentHash, - Skipped: true, - }, nil - } + // Get the last checkpoint's tree hash for deduplication + var lastTreeHash plumbing.Hash + if parentHash != plumbing.ZeroHash { + if lastCommit, lcErr := s.repo.CommitObject(parentHash); lcErr == nil { + lastTreeHash = lastCommit.TreeHash + } + } - // Create checkpoint commit with trailers - commitMsg := trailers.FormatShadowCommit(opts.CommitMessage, opts.MetadataDir, opts.SessionID) + treeHash, tErr := s.buildTreeWithChanges(ctx, baseTreeHash, allFiles, allDeletedFiles, opts.MetadataDir, opts.MetadataDirAbs) + if tErr != nil { + return fmt.Errorf("failed to build tree: %w", tErr) + } - commitHash, err := s.createCommit(ctx, treeHash, parentHash, commitMsg, opts.AuthorName, opts.AuthorEmail) - if err != nil { - return WriteTemporaryResult{}, fmt.Errorf("failed to create commit: %w", err) - } + // Deduplication: skip if tree hash matches the current shadow tip. 
+ if lastTreeHash != plumbing.ZeroHash && treeHash == lastTreeHash { + result = WriteTemporaryResult{ + CommitHash: parentHash, + Skipped: true, + } + return nil + } - // Update branch reference - refName := plumbing.NewBranchReferenceName(shadowBranchName) - newRef := plumbing.NewHashReference(refName, commitHash) - if err := s.repo.Storer.SetReference(newRef); err != nil { - return WriteTemporaryResult{}, fmt.Errorf("failed to update branch reference: %w", err) - } + commitHash, cErr := s.createCommit(ctx, treeHash, parentHash, commitMsg, opts.AuthorName, opts.AuthorEmail) + if cErr != nil { + return fmt.Errorf("failed to create commit: %w", cErr) + } - return WriteTemporaryResult{ - CommitHash: commitHash, - Skipped: false, - }, nil + refErr := casUpdateShadowBranchRef(ctx, repoRoot, shadowBranchName, commitHash, parentHash) + if refErr == nil { + result = WriteTemporaryResult{ + CommitHash: commitHash, + Skipped: false, + } + return nil + } + if !errors.Is(refErr, ErrShadowRefBusy) { + return fmt.Errorf("failed to update shadow branch reference: %w", refErr) + } + // Our commit is now dangling — best-effort remove it so we don't + // leak loose objects across many losing attempts. + tryDeleteLooseObject(commonDir, commitHash) + if bErr := shadowRefBackoff(ctx, attempt); bErr != nil { + return bErr + } + } + // Retry budget exhausted. With the flock held this means an external + // writer beat us shadowRefMaxRetries times in a row — surface it in + // logs so operators can see a stuck shadow branch. 
+ logging.Warn(logging.WithComponent(ctx, "checkpoint"), + "shadow branch CAS retry budget exhausted", + slog.String("shadow_branch", shadowBranchName), + slog.Int("retries", shadowRefMaxRetries), + ) + return fmt.Errorf("failed to update shadow branch reference after %d CAS retries: %w", shadowRefMaxRetries, ErrShadowRefBusy) + }) + if err != nil { + return WriteTemporaryResult{}, err + } + return result, nil } // ReadTemporary reads the latest checkpoint from a shadow branch. @@ -256,12 +290,6 @@ func (s *GitStore) WriteTemporaryTask(ctx context.Context, opts WriteTemporaryTa // Get shadow branch name shadowBranchName := ShadowBranchNameForCommit(opts.BaseCommit, opts.WorktreeID) - // Get or create shadow branch - parentHash, baseTreeHash, err := s.getOrCreateShadowBranch(shadowBranchName) - if err != nil { - return plumbing.ZeroHash, fmt.Errorf("failed to get shadow branch: %w", err) - } - // Collect all files to include in the commit. // Filter out gitignored files — subagent transcripts may report files like .env // that exist on disk but are gitignored. Without filtering, secrets would leak @@ -271,32 +299,58 @@ func (s *GitStore) WriteTemporaryTask(ctx context.Context, opts WriteTemporaryTa candidateFiles = append(candidateFiles, opts.NewFiles...) 
allFiles := filterGitIgnoredFiles(ctx, s.repo, candidateFiles) - // Build new tree with code changes (no metadata dir yet) - newTreeHash, err := s.buildTreeWithChanges(ctx, baseTreeHash, allFiles, opts.DeletedFiles, "", "") + repoRoot, commonDir, err := s.repoDirs(ctx) if err != nil { - return plumbing.ZeroHash, fmt.Errorf("failed to build tree: %w", err) + return plumbing.ZeroHash, fmt.Errorf("failed to resolve repo dirs: %w", err) } - // Add task metadata to tree - newTreeHash, err = s.addTaskMetadataToTree(ctx, newTreeHash, opts) - if err != nil { - return plumbing.ZeroHash, fmt.Errorf("failed to add task metadata: %w", err) - } + var resultHash plumbing.Hash + err = withShadowBranchFlock(commonDir, shadowBranchName, func() error { + for attempt := range shadowRefMaxRetries { + parentHash, baseTreeHash, gErr := s.getOrCreateShadowBranch(shadowBranchName) + if gErr != nil { + return fmt.Errorf("failed to get shadow branch: %w", gErr) + } - // Create the commit - commitHash, err := s.createCommit(ctx, newTreeHash, parentHash, opts.CommitMessage, opts.AuthorName, opts.AuthorEmail) - if err != nil { - return plumbing.ZeroHash, fmt.Errorf("failed to create commit: %w", err) - } + newTreeHash, tErr := s.buildTreeWithChanges(ctx, baseTreeHash, allFiles, opts.DeletedFiles, "", "") + if tErr != nil { + return fmt.Errorf("failed to build tree: %w", tErr) + } - // Update shadow branch reference - refName := plumbing.NewBranchReferenceName(shadowBranchName) - ref := plumbing.NewHashReference(refName, commitHash) - if err := s.repo.Storer.SetReference(ref); err != nil { - return plumbing.ZeroHash, fmt.Errorf("failed to update shadow branch reference: %w", err) - } + newTreeHash, tErr = s.addTaskMetadataToTree(ctx, newTreeHash, opts) + if tErr != nil { + return fmt.Errorf("failed to add task metadata: %w", tErr) + } - return commitHash, nil + commitHash, cErr := s.createCommit(ctx, newTreeHash, parentHash, opts.CommitMessage, opts.AuthorName, opts.AuthorEmail) + if cErr != 
nil { + return fmt.Errorf("failed to create commit: %w", cErr) + } + + refErr := casUpdateShadowBranchRef(ctx, repoRoot, shadowBranchName, commitHash, parentHash) + if refErr == nil { + resultHash = commitHash + return nil + } + if !errors.Is(refErr, ErrShadowRefBusy) { + return fmt.Errorf("failed to update shadow branch reference: %w", refErr) + } + tryDeleteLooseObject(commonDir, commitHash) + if bErr := shadowRefBackoff(ctx, attempt); bErr != nil { + return bErr + } + } + logging.Warn(logging.WithComponent(ctx, "checkpoint"), + "shadow branch CAS retry budget exhausted (task checkpoint)", + slog.String("shadow_branch", shadowBranchName), + slog.Int("retries", shadowRefMaxRetries), + ) + return fmt.Errorf("failed to update shadow branch reference after %d CAS retries: %w", shadowRefMaxRetries, ErrShadowRefBusy) + }) + if err != nil { + return plumbing.ZeroHash, err + } + return resultHash, nil } // addTaskMetadataToTree adds task checkpoint metadata to a git tree. diff --git a/cmd/trace/cli/checkpoint/v2_pending_rotation.go b/cmd/trace/cli/checkpoint/v2_pending_rotation.go index be15e4a..75b1b6b 100644 --- a/cmd/trace/cli/checkpoint/v2_pending_rotation.go +++ b/cmd/trace/cli/checkpoint/v2_pending_rotation.go @@ -18,7 +18,7 @@ import ( const ( pendingV2FullGenerationPublicationVersion = 1 - pendingV2FullGenerationPublicationDirName = "entire-v2-rotations" + pendingV2FullGenerationPublicationDirName = "trace-v2-rotations" pendingV2FullGenerationPublicationFile = "pending.json" pendingV2FullGenerationPublicationLock = "pending.lock" pendingV2FullGenerationPublicationLockTTL = 5 * time.Second diff --git a/cmd/trace/cli/git_operations.go b/cmd/trace/cli/git_operations.go index d9e0695..de72323 100644 --- a/cmd/trace/cli/git_operations.go +++ b/cmd/trace/cli/git_operations.go @@ -128,7 +128,7 @@ func IsOnDefaultBranch(ctx context.Context) (bool, string, error) { // If we couldn't determine from remote, use common defaults if defaultBranch == "" { // Check if current 
branch is a common default name - if currentBranch == "main" || currentBranch == "master" { + if currentBranch == defaultBaseBranch || currentBranch == masterBaseBranch { return true, currentBranch, nil } return false, currentBranch, nil @@ -151,11 +151,11 @@ func getDefaultBranchFromRemote(repo *git.Repository) string { } // Fallback: check if origin/main or origin/master exists - if _, err := repo.Reference(plumbing.NewRemoteReferenceName("origin", "main"), true); err == nil { - return "main" + if _, err := repo.Reference(plumbing.NewRemoteReferenceName("origin", defaultBaseBranch), true); err == nil { + return defaultBaseBranch } - if _, err := repo.Reference(plumbing.NewRemoteReferenceName("origin", "master"), true); err == nil { - return "master" + if _, err := repo.Reference(plumbing.NewRemoteReferenceName("origin", masterBaseBranch), true); err == nil { + return masterBaseBranch } return "" diff --git a/cmd/trace/cli/hooks_cmd.go b/cmd/trace/cli/hooks_cmd.go index d9c809a..1a28f25 100644 --- a/cmd/trace/cli/hooks_cmd.go +++ b/cmd/trace/cli/hooks_cmd.go @@ -16,6 +16,7 @@ import ( _ "github.com/GrayCodeAI/trace/cmd/trace/cli/agent/factoryaidroid" _ "github.com/GrayCodeAI/trace/cmd/trace/cli/agent/geminicli" _ "github.com/GrayCodeAI/trace/cmd/trace/cli/agent/opencode" + _ "github.com/GrayCodeAI/trace/cmd/trace/cli/agent/pi" _ "github.com/GrayCodeAI/trace/cmd/trace/cli/agent/vogon" // support external agents diff --git a/cmd/trace/cli/integration_test/codex_post_tool_use_test.go b/cmd/trace/cli/integration_test/codex_post_tool_use_test.go index c6b3add..0d666af 100644 --- a/cmd/trace/cli/integration_test/codex_post_tool_use_test.go +++ b/cmd/trace/cli/integration_test/codex_post_tool_use_test.go @@ -26,7 +26,7 @@ func TestCodexPostToolUse_PopulatesFilesTouched(t *testing.T) { env := NewRepoWithCommit(t) sessionID := "test-codex-post-tool-use" - statePath := filepath.Join(env.RepoDir, ".git", "entire-sessions", sessionID+".json") + statePath := 
filepath.Join(env.RepoDir, ".git", "trace-sessions", sessionID+".json") require.NoError(t, os.MkdirAll(filepath.Dir(statePath), 0o755)) // Pre-create state with AgentType=Codex. We skip UserPromptSubmit because @@ -85,7 +85,7 @@ func TestCodexPostToolUse_NonMutatingToolIsNoop(t *testing.T) { env := NewRepoWithCommit(t) sessionID := "test-codex-post-tool-use-noop" - statePath := filepath.Join(env.RepoDir, ".git", "entire-sessions", sessionID+".json") + statePath := filepath.Join(env.RepoDir, ".git", "trace-sessions", sessionID+".json") require.NoError(t, os.MkdirAll(filepath.Dir(statePath), 0o755)) initialState := map[string]any{ diff --git a/cmd/trace/cli/integration_test/cursor_forwarding_test.go b/cmd/trace/cli/integration_test/cursor_forwarding_test.go index d9de435..75cd335 100644 --- a/cmd/trace/cli/integration_test/cursor_forwarding_test.go +++ b/cmd/trace/cli/integration_test/cursor_forwarding_test.go @@ -23,7 +23,7 @@ func TestDispatcher_ForwardedStopFromNonOwnerIsSkipped(t *testing.T) { env := NewRepoWithCommit(t) sessionID := "test-cursor-forward-stop" - statePath := filepath.Join(env.RepoDir, ".git", "entire-sessions", sessionID+".json") + statePath := filepath.Join(env.RepoDir, ".git", "trace-sessions", sessionID+".json") require.NoError(t, os.MkdirAll(filepath.Dir(statePath), 0o755)) // Pre-record state with AgentType=Cursor: the firing claude-code hook @@ -71,7 +71,7 @@ func TestDispatcher_ForwardedSessionEndFromNonOwnerIsSkipped(t *testing.T) { env := NewRepoWithCommit(t) sessionID := "test-cursor-forward-sessionend" - statePath := filepath.Join(env.RepoDir, ".git", "entire-sessions", sessionID+".json") + statePath := filepath.Join(env.RepoDir, ".git", "trace-sessions", sessionID+".json") require.NoError(t, os.MkdirAll(filepath.Dir(statePath), 0o755)) initialState := map[string]any{ diff --git a/cmd/trace/cli/integration_test/hooks.go b/cmd/trace/cli/integration_test/hooks.go index be924e2..048814d 100644 --- a/cmd/trace/cli/integration_test/hooks.go 
+++ b/cmd/trace/cli/integration_test/hooks.go @@ -1481,3 +1481,73 @@ func (env *TestEnv) CopyTranscriptToTraceTmp(sessionID, transcriptPath string) { env.T.Fatalf("CopyTranscriptToTraceTmp: failed to write transcript to %q: %v", destPath, err) } } + +// CodexHookRunner executes Codex CLI hooks in the test environment. +type CodexHookRunner struct { + RepoDir string + T interface { + Helper() + Fatalf(format string, args ...interface{}) + Logf(format string, args ...interface{}) + } +} + +// NewCodexHookRunner creates a hook runner for Codex hooks in the given repo. +func NewCodexHookRunner(repoDir string, t interface { + Helper() + Fatalf(format string, args ...interface{}) + Logf(format string, args ...interface{}) +}, +) *CodexHookRunner { + return &CodexHookRunner{ + RepoDir: repoDir, + T: t, + } +} + +// runCodexHook runs a Codex hook by name with the given JSON input via stdin. +func (r *CodexHookRunner) runCodexHook(hookName string, input interface{}) error { + r.T.Helper() + + inputJSON, err := json.Marshal(input) + if err != nil { + return fmt.Errorf("failed to marshal hook input: %w", err) + } + + cmd := exec.Command(getTestBinary(), "hooks", "codex", hookName) + cmd.Dir = r.RepoDir + cmd.Stdin = bytes.NewReader(inputJSON) + cmd.Env = testutil.GitIsolatedEnv() + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("codex hook %s failed: %w\nInput: %s\nOutput: %s", + hookName, err, inputJSON, output) + } + + r.T.Logf("Codex hook %s output: %s", hookName, output) + return nil +} + +// SimulateCodexPostToolUseApplyPatch simulates a Codex PostToolUse hook +// for an apply_patch tool invocation. The patch string is wrapped in the +// Codex tool_input envelope before being dispatched. 
+func (r *CodexHookRunner) SimulateCodexPostToolUseApplyPatch(sessionID, cwd, patch string) error { + r.T.Helper() + + input := map[string]any{ + "session_id": sessionID, + "turn_id": "t1", + "transcript_path": nil, + "cwd": cwd, + "hook_event_name": "PostToolUse", + "model": "gpt-5", + "permission_mode": "default", + "tool_name": "apply_patch", + "tool_use_id": "call-patch", + "tool_input": map[string]any{"patch": patch}, + "tool_response": "Patch applied successfully.", + } + + return r.runCodexHook("post-tool-use", input) +} diff --git a/cmd/trace/cli/integration_test/review_test.go b/cmd/trace/cli/integration_test/review_test.go index 92273b2..42c2a48 100644 --- a/cmd/trace/cli/integration_test/review_test.go +++ b/cmd/trace/cli/integration_test/review_test.go @@ -243,7 +243,7 @@ func TestReview_MissingSkillAtSpawn_ErrorsCleanly(t *testing.T) { if !strings.Contains(output, "not installed") { t.Errorf("stderr should mention skill not installed; got:\n%s", output) } - if _, err := os.Stat(filepath.Join(env.RepoDir, ".git", "entire-sessions", "review-pending.json")); !os.IsNotExist(err) { + if _, err := os.Stat(filepath.Join(env.RepoDir, ".git", "trace-sessions", "review-pending.json")); !os.IsNotExist(err) { t.Errorf("pending marker should not exist; stat err=%v", err) } } diff --git a/cmd/trace/cli/internal/flock/flock_unix.go b/cmd/trace/cli/internal/flock/flock_unix.go new file mode 100644 index 0000000..680153f --- /dev/null +++ b/cmd/trace/cli/internal/flock/flock_unix.go @@ -0,0 +1,30 @@ +//go:build unix + +// Package flock provides a small cross-process advisory-lock primitive built +// on POSIX flock (Unix) / LockFileEx (Windows). It exists so that checkpoint +// and strategy can both serialize on shared resources without one taking +// the other as an import dependency. +package flock + +import ( + "fmt" + "os" + "syscall" +) + +// Acquire takes an exclusive advisory lock on path, creating the file if +// needed. 
// Acquire takes an exclusive advisory lock on path, creating the file if
// needed, and blocks until the lock is granted.
//
// The returned release closes the file, which drops the flock. Callers must
// invoke release exactly once. The lock file is left on disk between runs;
// the lock is tied to the open file, so the file's contents are immaterial.
//
// On failure release is nil and err wraps the underlying open or flock error.
func Acquire(path string) (release func(), err error) {
	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) //nolint:gosec // caller is responsible for path validation
	if err != nil {
		return nil, fmt.Errorf("open flock: %w", err)
	}

	// A blocking flock can be interrupted by a signal delivered to the
	// thread, in which case it reports EINTR without having taken the lock.
	// That is not a failure of the lock itself, so retry until the call
	// either succeeds or fails for a real reason. syscall.Flock returns a
	// bare syscall.Errno, so a direct comparison is sufficient.
	for {
		err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX) //nolint:gosec // file descriptors are non-negative; standard Go pattern for syscall.Flock
		if err != syscall.EINTR {
			break
		}
	}
	if err != nil {
		_ = f.Close()
		return nil, fmt.Errorf("flock: %w", err)
	}
	return func() { _ = f.Close() }, nil
}
+ rel2, err := Acquire(lockPath) + if err != nil { + t.Errorf("second Acquire() error = %v", err) + return + } + acquired.Store(true) + rel2() + }() + + // Give the goroutine time to attempt the lock. + time.Sleep(50 * time.Millisecond) + if acquired.Load() { + t.Error("second Acquire should have blocked while first lock is held") + } + + release1() + wg.Wait() + + if !acquired.Load() { + t.Error("second Acquire should have succeeded after release") + } +} + +func TestAcquire_ReleaseAllowsReacquire(t *testing.T) { + t.Parallel() + dir := t.TempDir() + lockPath := filepath.Join(dir, "test.lock") + + rel1, err := Acquire(lockPath) + if err != nil { + t.Fatalf("first Acquire() error = %v", err) + } + rel1() + + // Should be able to acquire again immediately. + rel2, err := Acquire(lockPath) + if err != nil { + t.Fatalf("second Acquire() after release error = %v", err) + } + rel2() +} + +func TestAcquire_InvalidPath(t *testing.T) { + t.Parallel() + _, err := Acquire("/nonexistent/dir/lock.file") + if err == nil { + t.Error("Acquire() should return error for invalid path") + } +} diff --git a/cmd/trace/cli/strategy/state_lock_windows.go b/cmd/trace/cli/internal/flock/flock_windows.go similarity index 52% rename from cmd/trace/cli/strategy/state_lock_windows.go rename to cmd/trace/cli/internal/flock/flock_windows.go index e6cff2b..93cc0c4 100644 --- a/cmd/trace/cli/strategy/state_lock_windows.go +++ b/cmd/trace/cli/internal/flock/flock_windows.go @@ -1,6 +1,6 @@ //go:build windows -package strategy +package flock import ( "fmt" @@ -9,18 +9,18 @@ import ( "golang.org/x/sys/windows" ) -// acquireStateFileLock takes an exclusive lock on path via Windows -// LockFileEx. The returned release unlocks and closes the file. Callers -// must call release exactly once. 
-func acquireStateFileLock(path string) (release func(), err error) { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) //nolint:gosec // path built from validated session ID +// Acquire takes an exclusive lock on path via Windows LockFileEx. The +// returned release unlocks and closes the file. Callers must invoke release +// exactly once. +func Acquire(path string) (release func(), err error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) //nolint:gosec // caller is responsible for path validation if err != nil { - return nil, fmt.Errorf("open state lock: %w", err) + return nil, fmt.Errorf("open flock: %w", err) } overlapped := new(windows.Overlapped) if err := windows.LockFileEx(windows.Handle(f.Fd()), windows.LOCKFILE_EXCLUSIVE_LOCK, 0, 1, 0, overlapped); err != nil { _ = f.Close() - return nil, fmt.Errorf("lock state lock: %w", err) + return nil, fmt.Errorf("lock flock: %w", err) } return func() { _ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, 1, 0, overlapped) diff --git a/cmd/trace/cli/jsonutil/write.go b/cmd/trace/cli/jsonutil/write.go new file mode 100644 index 0000000..38f4f02 --- /dev/null +++ b/cmd/trace/cli/jsonutil/write.go @@ -0,0 +1,72 @@ +package jsonutil + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" +) + +// WriteFileAtomic writes data to filePath atomically by writing to a temp file +// in the same directory, fsyncing it, renaming into place, and fsyncing the +// parent directory. A crash or signal mid-write leaves the original file +// intact rather than a truncated partial — important for config files like +// .entire/settings.json that callers expect to remain parseable across +// interrupted writes. +// +// The fsync between Write and Close guarantees the temp file's bytes are on +// disk before the rename takes effect; without it, some filesystems (notably +// ext4 with non-default mount options) can surface the rename as completed +// while the file is still empty after a hard crash. 
// WriteFileAtomic replaces filePath with data without ever exposing a
// partially written file.
//
// The bytes are first written to a temp file created next to the target,
// fsynced, closed, chmod-ed to perm, and only then renamed over filePath.
// Finally the parent directory is fsynced so the rename itself is durable.
// If any step before the rename fails, the temp file is removed and the
// original file is left untouched.
//
// The directory fsync is best-effort: once the rename has succeeded the new
// file is already in place, so a failure to open or sync the directory (for
// example on Windows, which does not support directory fsync) is ignored
// rather than reported.
//
// perm is applied with Chmod before the rename so the final file carries the
// requested permission regardless of the temp file's default mode.
func WriteFileAtomic(filePath string, data []byte, perm fs.FileMode) error {
	parent := filepath.Dir(filePath)

	staging, err := os.CreateTemp(parent, filepath.Base(filePath)+".*.tmp")
	if err != nil {
		return fmt.Errorf("create temp for %s: %w", filePath, err)
	}
	stagingPath := staging.Name()

	// Until the rename lands, every exit path must delete the staging file.
	committed := false
	defer func() {
		if committed {
			return
		}
		_ = os.Remove(stagingPath)
	}()

	_, err = staging.Write(data)
	if err != nil {
		_ = staging.Close()
		return fmt.Errorf("write temp for %s: %w", filePath, err)
	}

	// Flush the contents to disk before the rename can make them visible.
	err = staging.Sync()
	if err != nil {
		_ = staging.Close()
		return fmt.Errorf("sync temp for %s: %w", filePath, err)
	}

	err = staging.Close()
	if err != nil {
		return fmt.Errorf("close temp for %s: %w", filePath, err)
	}

	err = os.Chmod(stagingPath, perm)
	if err != nil {
		return fmt.Errorf("chmod temp for %s: %w", filePath, err)
	}

	err = os.Rename(stagingPath, filePath)
	if err != nil {
		return fmt.Errorf("rename temp to %s: %w", filePath, err)
	}
	committed = true

	// Best-effort durability of the directory entry; errors are deliberately
	// not propagated because the file is already in place.
	parentDir, openErr := os.Open(parent) //nolint:gosec // G304: dir is filepath.Dir of caller-supplied filePath, not user input
	if openErr != nil {
		return nil
	}
	_ = parentDir.Sync() //nolint:errcheck // best-effort directory fsync; failure does not roll back the rename
	_ = parentDir.Close()
	return nil
}
+func TestWriteFileAtomic_AppliesPermission(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("POSIX permission bits are not meaningful on Windows") + } + dir := t.TempDir() + target := filepath.Join(dir, "out.json") + + if err := WriteFileAtomic(target, []byte("x"), 0o600); err != nil { + t.Fatalf("WriteFileAtomic: %v", err) + } + + info, err := os.Stat(target) + if err != nil { + t.Fatalf("stat: %v", err) + } + if got := info.Mode().Perm(); got != 0o600 { + t.Errorf("perm: got %#o want %#o", got, 0o600) + } +} + +// LeavesNoTempOnSuccess guards against the removeTmp defer being skipped or +// the temp suffix changing in a way that breaks cleanup. +func TestWriteFileAtomic_LeavesNoTempOnSuccess(t *testing.T) { + t.Parallel() + dir := t.TempDir() + target := filepath.Join(dir, "out.json") + + if err := WriteFileAtomic(target, []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFileAtomic: %v", err) + } + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + if len(entries) != 1 { + names := make([]string, 0, len(entries)) + for _, e := range entries { + names = append(names, e.Name()) + } + t.Errorf("expected exactly one entry in dir, got %d: %v", len(entries), names) + } + for _, e := range entries { + if strings.HasSuffix(e.Name(), ".tmp") { + t.Errorf("leftover temp file: %s", e.Name()) + } + } +} + +// CleansUpTempOnRenameFailure reaches the rename step and forces it to fail +// (renaming a regular file onto a non-empty directory is rejected on every +// POSIX filesystem, and on Windows). The removeTmp defer must clear the +// orphan so /tmp doesn't accumulate junk across many failed writes. 
+func TestWriteFileAtomic_CleansUpTempOnRenameFailure(t *testing.T) { + t.Parallel() + dir := t.TempDir() + target := filepath.Join(dir, "out.json") + if err := os.Mkdir(target, 0o755); err != nil { + t.Fatalf("mkdir target: %v", err) + } + if err := os.WriteFile(filepath.Join(target, "occupant"), []byte("x"), 0o644); err != nil { + t.Fatalf("seed dir: %v", err) + } + + err := WriteFileAtomic(target, []byte("x"), 0o644) + if err == nil { + t.Fatal("expected error when target is a non-empty directory") + } + + info, statErr := os.Stat(target) + if statErr != nil { + t.Fatalf("stat target: %v", statErr) + } + if !info.IsDir() { + t.Error("target should still be a directory after failed rename") + } + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + for _, e := range entries { + if strings.HasSuffix(e.Name(), ".tmp") { + t.Errorf("leftover temp file after failed rename: %s", e.Name()) + } + } +} + +func TestWriteFileAtomic_ParentMissing(t *testing.T) { + t.Parallel() + dir := t.TempDir() + target := filepath.Join(dir, "does-not-exist", "out.json") + + err := WriteFileAtomic(target, []byte("x"), 0o644) + if err == nil { + t.Fatal("expected error when parent dir is missing") + } + if !errors.Is(err, os.ErrNotExist) { + t.Errorf("expected ErrNotExist; got: %v", err) + } +} diff --git a/cmd/trace/cli/lifecycle.go b/cmd/trace/cli/lifecycle.go index 657c23f..0333154 100644 --- a/cmd/trace/cli/lifecycle.go +++ b/cmd/trace/cli/lifecycle.go @@ -55,6 +55,8 @@ func DispatchLifecycleEvent(ctx context.Context, ag agent.Agent, event *agent.Ev return handleLifecycleSubagentEnd(ctx, ag, event) case agent.ModelUpdate: return handleLifecycleModelUpdate(ctx, ag, event) + case agent.ToolUse: + return handleLifecycleToolUse(ctx, ag, event) default: return fmt.Errorf("unknown lifecycle event type: %d", event.Type) } @@ -204,6 +206,39 @@ func handleLifecycleModelUpdate(ctx context.Context, ag agent.Agent, event *agen return nil } +// 
handleLifecycleToolUse merges a tool's file-change lists into session state. +// This keeps FilesTouched accurate during a turn so mid-turn commits have +// correct carry-forward data. +func handleLifecycleToolUse(ctx context.Context, ag agent.Agent, event *agent.Event) error { + logCtx := logging.WithAgent(logging.WithComponent(ctx, "lifecycle"), ag.Name()) + + if event.SessionID == "" { + return nil + } + + totalFiles := len(event.ModifiedFiles) + len(event.NewFiles) + len(event.DeletedFiles) + if totalFiles == 0 { + return nil + } + + logging.Info( + logCtx, "tool-use: recording files touched", + slog.String("session_id", event.SessionID), + slog.String("tool", event.ToolName), + slog.Int("modified", len(event.ModifiedFiles)), + slog.Int("added", len(event.NewFiles)), + slog.Int("deleted", len(event.DeletedFiles)), + ) + + if err := strategy.RecordFilesTouched(ctx, event.SessionID, event.ModifiedFiles, event.NewFiles, event.DeletedFiles); err != nil { + // RecordFilesTouched no-ops on ErrStateNotFound — log and continue. + logging.Debug(logCtx, "tool-use: RecordFilesTouched skipped", + slog.String("error", err.Error())) + } + + return nil +} + // handleLifecycleTurnStart handles turn start: captures pre-prompt state, // ensures strategy setup, initializes session. 
func handleLifecycleTurnStart(ctx context.Context, ag agent.Agent, event *agent.Event) error { diff --git a/cmd/trace/cli/resume.go b/cmd/trace/cli/resume.go index 62254f9..36711e3 100644 --- a/cmd/trace/cli/resume.go +++ b/cmd/trace/cli/resume.go @@ -597,7 +597,7 @@ func findBranchCheckpoints(repo *git.Repository, branchName string) (*branchChec defaultBranch := getDefaultBranchFromRemote(repo) if defaultBranch == "" { // Fallback: try common names - for _, name := range []string{"main", "master"} { + for _, name := range []string{defaultBaseBranch, masterBaseBranch} { if _, err := repo.Reference(plumbing.NewBranchReferenceName(name), true); err == nil { defaultBranch = name break diff --git a/cmd/trace/cli/review/migration.go b/cmd/trace/cli/review/migration.go new file mode 100644 index 0000000..1bd0d47 --- /dev/null +++ b/cmd/trace/cli/review/migration.go @@ -0,0 +1,285 @@ +package review + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + + "github.com/GrayCodeAI/trace/cmd/trace/cli/logging" + "github.com/GrayCodeAI/trace/cmd/trace/cli/settings" +) + +type projectReviewSettings struct { + path string + raw map[string]json.RawMessage + review json.RawMessage + fixAgent json.RawMessage + hasReview bool + hasFixAgent bool +} + +func maybePromptReviewSettingsMigration( + ctx context.Context, + out io.Writer, + errOut io.Writer, + canPrompt bool, + promptYN func(context.Context, string, bool) (bool, error), +) error { + project, ok, err := loadProjectReviewSettings(ctx) + if err != nil { + return err + } + if !ok { + return nil + } + + // Skip the prompt entirely if the user has already declined. Without this, + // teams who intentionally commit review prefs would be re-prompted on + // every invocation of `entire review`. 
+ prefs, prefsErr := settings.LoadClonePreferences(ctx) + if prefsErr != nil { + return fmt.Errorf("load review preferences for migration: %w", prefsErr) + } + if prefs != nil && prefs.ReviewMigrationDismissed { + return nil + } + + // Bail before prompting if .entire/settings.local.json already has review + // keys. settings.local.json overrides clone-local preferences (mergeJSON + // wholesale-replaces the review map), so migrating without cleaning the + // local file first would silently nullify the migration on the very next + // settings.Load — the user clicks "yes", their config moves to clone + // prefs, then the local override hides it. Better to surface the + // precondition up front than to leave the user wondering why their + // migrated config disappeared. + // + // Intentionally does NOT set ReviewMigrationDismissed: this is a fixable + // precondition, not a user-rejected migration; the prompt should fire + // again on the next run after the user cleans settings.local.json. + if localHas, localPath, localErr := localSettingsHasReviewKeys(ctx); localErr != nil { + return fmt.Errorf("inspect local settings for migration: %w", localErr) + } else if localHas { + fmt.Fprintln(errOut, "Cannot migrate review preferences: .entire/settings.local.json also has review keys.") + fmt.Fprintf(errOut, "Those override clone-local preferences and would mask the migration. Remove the\n") + fmt.Fprintf(errOut, "`review` / `review_fix_agent` keys from %s, then re-run `entire review`.\n", localPath) + return nil + } + + if !canPrompt { + // Log at Warn so operators tailing .entire/logs/ catch the pending + // migration on scripted/CI invocations where the stderr hint may + // scroll past unnoticed. 
+ logging.Warn(ctx, "review migration pending: project settings has review keys that may be committed", + slog.String("project_settings_path", project.path), + slog.Bool("has_review", project.hasReview), + slog.Bool("has_fix_agent", project.hasFixAgent)) + fmt.Fprintln(errOut, "Review preferences are stored in project settings (.entire/settings.json).") + fmt.Fprintln(errOut, "These are typically committed and may be visible to teammates.") + fmt.Fprintln(errOut, "Run `entire review --edit` interactively to move them to clone-local preferences.") + return nil + } + + if promptYN == nil { + promptYN = realPromptYN + } + migrate, err := promptYN(ctx, "Review preferences are stored in project settings (.entire/settings.json), which is typically committed. Move them to clone-local preferences so they stay private?", false) + if err != nil { + return fmt.Errorf("review settings migration prompt: %w", err) + } + if !migrate { + if prefs == nil { + prefs = &settings.ClonePreferences{} + } + prefs.ReviewMigrationDismissed = true + if err := settings.SaveClonePreferences(ctx, prefs); err != nil { + return fmt.Errorf("save migration dismissal: %w", err) + } + return nil + } + + moved, err := migrateProjectReviewSettings(ctx, project) + if err != nil { + return err + } + if moved { + fmt.Fprintln(out, "Moved review preferences from project settings to clone-local preferences.") + } else { + fmt.Fprintln(out, "Removed unused review keys from project settings; nothing to move.") + } + return nil +} + +func loadProjectReviewSettings(ctx context.Context) (*projectReviewSettings, bool, error) { + path, raw, exists, err := settings.LoadProjectRaw(ctx) + if err != nil { + return nil, false, fmt.Errorf("review migration: %w", err) + } + if !exists { + return nil, false, nil + } + + reviewRaw, hasReview := raw["review"] + fixAgentRaw, hasFixAgent := raw["review_fix_agent"] + if !hasReview && !hasFixAgent { + return nil, false, nil + } + return &projectReviewSettings{ + path: path, + 
raw: raw, + review: reviewRaw, + fixAgent: fixAgentRaw, + hasReview: hasReview, + hasFixAgent: hasFixAgent, + }, true, nil +} + +// migrateProjectReviewSettings copies review keys from the project settings +// file into clone-local preferences and strips them from the project file. +// +// Returns moved=true when any review data was copied into prefs. When the +// project file's review keys are empty/null (or fully conflict with existing +// prefs, which is rejected upstream), moved=false but the project keys are +// still stripped as cleanup. +// +// Write ordering: prefs are saved first (atomic), then the project file is +// rewritten (atomic). Both writes use temp-then-rename so a crash mid-write +// leaves the original file intact rather than truncated. If the project +// rewrite fails after the prefs write succeeded, prefs precedence covers +// the gap until the next run. +func migrateProjectReviewSettings(ctx context.Context, project *projectReviewSettings) (moved bool, err error) { + if project == nil { + return false, nil + } + + prefs, err := settings.LoadClonePreferences(ctx) + if err != nil { + return false, fmt.Errorf("load review preferences for migration: %w", err) + } + if prefs == nil { + prefs = &settings.ClonePreferences{} + } + + preferencesChanged := false + if project.hasReview && !isJSONNull(project.review) { + var projectReview map[string]settings.ReviewConfig + if err := json.Unmarshal(project.review, &projectReview); err != nil { + return false, fmt.Errorf("parsing project review settings: %w", err) + } + if len(projectReview) > 0 { + merged, mergedOK, conflicts := mergeProjectReviewIntoPrefs(prefs.Review, projectReview) + if len(conflicts) > 0 { + return false, fmt.Errorf( + "review settings exist in both %s and clone-local preferences for agent(s) %v; "+ + "reconcile manually by removing the redundant keys from %s, then re-run `entire review`", + project.path, conflicts, project.path, + ) + } + if mergedOK { + prefs.Review = merged + 
preferencesChanged = true + } + } + } + if project.hasFixAgent && !isJSONNull(project.fixAgent) { + var fixAgent string + if err := json.Unmarshal(project.fixAgent, &fixAgent); err != nil { + return false, fmt.Errorf("parsing project review_fix_agent: %w", err) + } + if fixAgent != "" { + if prefs.ReviewFixAgent != "" && prefs.ReviewFixAgent != fixAgent { + return false, fmt.Errorf( + "review_fix_agent differs between %s (%q) and clone-local preferences (%q); "+ + "reconcile manually by removing review_fix_agent from %s, then re-run `entire review`", + project.path, fixAgent, prefs.ReviewFixAgent, project.path, + ) + } + if prefs.ReviewFixAgent == "" { + prefs.ReviewFixAgent = fixAgent + preferencesChanged = true + } + } + } + + if preferencesChanged { + if err := settings.SaveClonePreferences(ctx, prefs); err != nil { + return false, fmt.Errorf("save review preferences for migration: %w", err) + } + } + + delete(project.raw, "review") + delete(project.raw, "review_fix_agent") + if err := settings.SaveProjectRaw(project.path, project.raw); err != nil { + return false, fmt.Errorf("save project settings after review migration: %w", err) + } + return preferencesChanged, nil +} + +// mergeProjectReviewIntoPrefs merges projectReview into the current prefs map. +// Per-agent conflicts (same key, different value) are surfaced rather than +// silently resolved — the caller can then refuse the migration with a clear +// message. Non-overlapping entries are merged. Returns ok=false when nothing +// would change (prefs already had every project entry verbatim). 
+func mergeProjectReviewIntoPrefs(prefs, projectReview map[string]settings.ReviewConfig) (merged map[string]settings.ReviewConfig, ok bool, conflicts []string) { + merged = make(map[string]settings.ReviewConfig, len(prefs)+len(projectReview)) + for k, v := range prefs { + merged[k] = v + } + changed := false + for k, projectV := range projectReview { + if existing, present := merged[k]; present { + if !reviewConfigEqual(existing, projectV) { + conflicts = append(conflicts, k) + } + continue + } + merged[k] = projectV + changed = true + } + if len(conflicts) > 0 { + return nil, false, conflicts + } + return merged, changed, nil +} + +func reviewConfigEqual(a, b settings.ReviewConfig) bool { + if a.Prompt != b.Prompt { + return false + } + if len(a.Skills) != len(b.Skills) { + return false + } + for i := range a.Skills { + if a.Skills[i] != b.Skills[i] { + return false + } + } + return true +} + +func isJSONNull(raw json.RawMessage) bool { + return bytes.Equal(bytes.TrimSpace(raw), []byte("null")) +} + +// localSettingsHasReviewKeys reports whether .entire/settings.local.json +// exists and contains either a "review" or "review_fix_agent" key. Both keys +// override clone-local preferences via mergeJSON's wholesale-replace path, +// so the migration must surface their presence rather than silently produce +// a state where the migrated config never takes effect. +// +// Returns the absolute path of the local settings file too, so callers can +// quote the exact location in the warning they show the user. 
+func localSettingsHasReviewKeys(ctx context.Context) (has bool, path string, err error) { + path, raw, exists, loadErr := settings.LoadLocalRaw(ctx) + if loadErr != nil { + return false, path, fmt.Errorf("local settings review-keys check: %w", loadErr) + } + if !exists { + return false, path, nil + } + _, hasReview := raw["review"] + _, hasFixAgent := raw["review_fix_agent"] + return hasReview || hasFixAgent, path, nil +} diff --git a/cmd/trace/cli/review/migration_test.go b/cmd/trace/cli/review/migration_test.go new file mode 100644 index 0000000..da446f8 --- /dev/null +++ b/cmd/trace/cli/review/migration_test.go @@ -0,0 +1,421 @@ +package review + +import ( + "bytes" + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/GrayCodeAI/trace/cmd/trace/cli/session" + "github.com/GrayCodeAI/trace/cmd/trace/cli/settings" + "github.com/GrayCodeAI/trace/cmd/trace/cli/testutil" +) + +func TestReviewSettingsMigration_MovesProjectReviewToClonePreferences(t *testing.T) { + tmp := t.TempDir() + testutil.InitRepo(t, tmp) + t.Chdir(tmp) + session.ClearGitCommonDirCache() + + entireDir := filepath.Join(tmp, ".entire") + if err := os.MkdirAll(entireDir, 0o750); err != nil { + t.Fatalf("mkdir .entire: %v", err) + } + projectSettings := []byte(`{ + "enabled": true, + "log_level": "debug", + "review": {"claude-code": {"skills": ["/review"], "prompt": "project"}}, + "review_fix_agent": "claude-code" + }`) + projectPath := filepath.Join(entireDir, "settings.json") + if err := os.WriteFile(projectPath, projectSettings, 0o600); err != nil { + t.Fatalf("write project settings: %v", err) + } + + prompted := false + promptQuestion := "" + var out bytes.Buffer + if err := maybePromptReviewSettingsMigration(context.Background(), &out, &out, true, func(_ context.Context, question string, _ bool) (bool, error) { + prompted = true + promptQuestion = question + return true, nil + }); err != nil { + t.Fatalf("migration: %v", err) + } + if !prompted { + 
t.Fatal("expected migration prompt") + } + for _, want := range []string{"project settings", "clone-local preferences", "typically committed"} { + if !strings.Contains(promptQuestion, want) { + t.Fatalf("migration prompt = %q, want it to mention %q", promptQuestion, want) + } + } + + prefs, err := settings.LoadClonePreferences(context.Background()) + if err != nil { + t.Fatalf("load preferences: %v", err) + } + if got := prefs.Review["claude-code"].Prompt; got != "project" { + t.Fatalf("migrated prompt = %q, want project", got) + } + if prefs.ReviewFixAgent != "claude-code" { + t.Fatalf("ReviewFixAgent = %q, want claude-code", prefs.ReviewFixAgent) + } + + raw := map[string]json.RawMessage{} + data, err := os.ReadFile(projectPath) + if err != nil { + t.Fatalf("read project settings: %v", err) + } + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("unmarshal project settings: %v", err) + } + if _, ok := raw["review"]; ok { + t.Fatalf("project review key was not removed: %s", data) + } + if _, ok := raw["review_fix_agent"]; ok { + t.Fatalf("project review_fix_agent key was not removed: %s", data) + } + if _, ok := raw["log_level"]; !ok { + t.Fatalf("unrelated project settings were not preserved: %s", data) + } +} + +// TestReviewSettingsMigration_MergesNonOverlappingPrefs verifies that when the +// project file has review keys for an agent NOT present in clone-local prefs, +// the migration merges them in. Previously the migration silently dropped any +// project config when prefs already had any review entry — that was data loss. 
+func TestReviewSettingsMigration_MergesNonOverlappingPrefs(t *testing.T) { + tmp := t.TempDir() + testutil.InitRepo(t, tmp) + t.Chdir(tmp) + session.ClearGitCommonDirCache() + + entireDir := filepath.Join(tmp, ".entire") + if err := os.MkdirAll(entireDir, 0o750); err != nil { + t.Fatalf("mkdir .entire: %v", err) + } + projectPath := filepath.Join(entireDir, "settings.json") + projectSettings := []byte(`{ + "enabled": true, + "review": {"project-agent": {"prompt": "project"}} + }`) + if err := os.WriteFile(projectPath, projectSettings, 0o600); err != nil { + t.Fatalf("write project settings: %v", err) + } + if err := settings.SaveClonePreferences(context.Background(), &settings.ClonePreferences{ + Review: map[string]settings.ReviewConfig{ + "local-agent": {Prompt: "local"}, + }, + }); err != nil { + t.Fatalf("seed preferences: %v", err) + } + + var out bytes.Buffer + if err := maybePromptReviewSettingsMigration(context.Background(), &out, &out, true, func(context.Context, string, bool) (bool, error) { + return true, nil + }); err != nil { + t.Fatalf("migration: %v", err) + } + + prefs, err := settings.LoadClonePreferences(context.Background()) + if err != nil { + t.Fatalf("load preferences: %v", err) + } + if got := prefs.Review["local-agent"].Prompt; got != "local" { + t.Fatalf("local prompt = %q, want preserved as %q", got, "local") + } + if got := prefs.Review["project-agent"].Prompt; got != "project" { + t.Fatalf("project prompt = %q, want merged in as %q", got, "project") + } + + data, err := os.ReadFile(projectPath) + if err != nil { + t.Fatalf("read project settings: %v", err) + } + raw := map[string]json.RawMessage{} + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("unmarshal project settings: %v", err) + } + if _, ok := raw["review"]; ok { + t.Fatalf("project review key was not removed: %s", data) + } +} + +// TestReviewSettingsMigration_RefusesConflictingPrefs verifies that when both +// the project file and clone-local prefs have review 
config for the SAME agent +// with DIFFERENT values, the migration aborts with a clear error rather than +// silently dropping one side. The user must reconcile manually. +func TestReviewSettingsMigration_RefusesConflictingPrefs(t *testing.T) { + tmp := t.TempDir() + testutil.InitRepo(t, tmp) + t.Chdir(tmp) + session.ClearGitCommonDirCache() + + entireDir := filepath.Join(tmp, ".entire") + if err := os.MkdirAll(entireDir, 0o750); err != nil { + t.Fatalf("mkdir .entire: %v", err) + } + projectPath := filepath.Join(entireDir, "settings.json") + projectSettings := []byte(`{ + "enabled": true, + "review": {"claude-code": {"prompt": "project"}} + }`) + if err := os.WriteFile(projectPath, projectSettings, 0o600); err != nil { + t.Fatalf("write project settings: %v", err) + } + if err := settings.SaveClonePreferences(context.Background(), &settings.ClonePreferences{ + Review: map[string]settings.ReviewConfig{ + "claude-code": {Prompt: "local"}, + }, + }); err != nil { + t.Fatalf("seed preferences: %v", err) + } + + var out bytes.Buffer + err := maybePromptReviewSettingsMigration(context.Background(), &out, &out, true, func(context.Context, string, bool) (bool, error) { + return true, nil + }) + if err == nil { + t.Fatal("expected migration to refuse conflicting prefs") + } + if !strings.Contains(err.Error(), "claude-code") { + t.Errorf("error = %q, want it to name the conflicting agent (claude-code)", err.Error()) + } + if !strings.Contains(err.Error(), "reconcile manually") { + t.Errorf("error = %q, want it to guide manual reconciliation", err.Error()) + } + + // Project file must NOT have been rewritten on the conflict path. + data, err := os.ReadFile(projectPath) + if err != nil { + t.Fatalf("read project settings: %v", err) + } + if !bytes.Contains(data, []byte("claude-code")) { + t.Fatalf("project file was modified despite conflict abort: %s", data) + } + + // Clone prefs must be unchanged. 
+ prefs, err := settings.LoadClonePreferences(context.Background()) + if err != nil { + t.Fatalf("load preferences: %v", err) + } + if got := prefs.Review["claude-code"].Prompt; got != "local" { + t.Errorf("local prompt = %q, want unchanged as %q", got, "local") + } +} + +// TestReviewSettingsMigration_NoMoveCleansUpKeys verifies the cleanup-only +// path: project has only `null` values for review keys, so nothing actually +// moves, but the project keys are still stripped and the success message +// reflects that distinction. +func TestReviewSettingsMigration_NoMoveCleansUpKeys(t *testing.T) { + tmp := t.TempDir() + testutil.InitRepo(t, tmp) + t.Chdir(tmp) + session.ClearGitCommonDirCache() + + entireDir := filepath.Join(tmp, ".entire") + if err := os.MkdirAll(entireDir, 0o750); err != nil { + t.Fatalf("mkdir .entire: %v", err) + } + projectPath := filepath.Join(entireDir, "settings.json") + if err := os.WriteFile(projectPath, []byte(`{ + "enabled": true, + "review": null, + "review_fix_agent": null + }`), 0o600); err != nil { + t.Fatalf("write project settings: %v", err) + } + + var out bytes.Buffer + if err := maybePromptReviewSettingsMigration(context.Background(), &out, &out, true, func(context.Context, string, bool) (bool, error) { + return true, nil + }); err != nil { + t.Fatalf("migration: %v", err) + } + if !strings.Contains(out.String(), "Removed unused review keys") { + t.Errorf("output = %q, want the cleanup-only message", out.String()) + } + + data, err := os.ReadFile(projectPath) + if err != nil { + t.Fatalf("read project settings: %v", err) + } + raw := map[string]json.RawMessage{} + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("unmarshal project settings: %v", err) + } + if _, ok := raw["review"]; ok { + t.Fatalf("project review key was not removed: %s", data) + } +} + +// TestReviewSettingsMigration_DeclinePersistsDismissal verifies that declining +// the prompt records ReviewMigrationDismissed in clone-local prefs, and that a +// 
subsequent invocation does NOT re-prompt. Without this, teams who +// intentionally commit review prefs would be re-prompted on every command. +func TestReviewSettingsMigration_DeclinePersistsDismissal(t *testing.T) { + tmp := t.TempDir() + testutil.InitRepo(t, tmp) + t.Chdir(tmp) + session.ClearGitCommonDirCache() + + entireDir := filepath.Join(tmp, ".entire") + if err := os.MkdirAll(entireDir, 0o750); err != nil { + t.Fatalf("mkdir .entire: %v", err) + } + projectPath := filepath.Join(entireDir, "settings.json") + projectSettings := []byte(`{ + "enabled": true, + "review": {"claude-code": {"prompt": "project"}} + }`) + if err := os.WriteFile(projectPath, projectSettings, 0o600); err != nil { + t.Fatalf("write project settings: %v", err) + } + + // First invocation: user declines. + var out bytes.Buffer + promptCount := 0 + declineThenFail := func(context.Context, string, bool) (bool, error) { + promptCount++ + return false, nil + } + if err := maybePromptReviewSettingsMigration(context.Background(), &out, &out, true, declineThenFail); err != nil { + t.Fatalf("first invocation: %v", err) + } + if promptCount != 1 { + t.Errorf("first invocation prompted %d times, want 1", promptCount) + } + + // Dismissal must be persisted. + prefs, err := settings.LoadClonePreferences(context.Background()) + if err != nil { + t.Fatalf("load preferences: %v", err) + } + if prefs == nil || !prefs.ReviewMigrationDismissed { + t.Fatalf("ReviewMigrationDismissed = false, want true after decline (prefs = %+v)", prefs) + } + + // Project file must be untouched on decline. + data, err := os.ReadFile(projectPath) + if err != nil { + t.Fatalf("read project settings: %v", err) + } + if !bytes.Contains(data, []byte("claude-code")) { + t.Errorf("project file was modified on decline: %s", data) + } + + // Second invocation: must NOT re-prompt. 
+ failIfPrompted := func(context.Context, string, bool) (bool, error) { + t.Fatal("prompt should not be called when dismissal is persisted") + return false, nil + } + if err := maybePromptReviewSettingsMigration(context.Background(), &out, &out, true, failIfPrompted); err != nil { + t.Fatalf("second invocation: %v", err) + } +} + +func TestReviewSettingsMigration_SkipsWhenProjectHasNoReviewKeys(t *testing.T) { + tmp := t.TempDir() + testutil.InitRepo(t, tmp) + t.Chdir(tmp) + session.ClearGitCommonDirCache() + + entireDir := filepath.Join(tmp, ".entire") + if err := os.MkdirAll(entireDir, 0o750); err != nil { + t.Fatalf("mkdir .entire: %v", err) + } + projectPath := filepath.Join(entireDir, "settings.json") + if err := os.WriteFile(projectPath, []byte(`{"enabled":true,"log_level":"debug"}`), 0o600); err != nil { + t.Fatalf("write project settings: %v", err) + } + + var out bytes.Buffer + if err := maybePromptReviewSettingsMigration(context.Background(), &out, &out, true, func(context.Context, string, bool) (bool, error) { + t.Fatal("prompt should not be called") + return false, nil + }); err != nil { + t.Fatalf("migration: %v", err) + } + + preferencesPath, err := settings.ClonePreferencesPath(context.Background()) + if err != nil { + t.Fatalf("preferences path: %v", err) + } + if _, err := os.Stat(preferencesPath); !os.IsNotExist(err) { + t.Fatalf("preferences file exists after no-op migration: %v", err) + } +} + +// TestReviewSettingsMigration_BailsOnLocalSettingsReviewKeys pins the +// precondition: when .entire/settings.local.json has review keys, those +// override clone-local preferences via mergeJSON's wholesale-replace path, +// so the migration must surface the conflict up front rather than silently +// produce a migrated-but-masked state. Bailing also intentionally does NOT +// set ReviewMigrationDismissed — this is a fixable precondition, not a +// rejected migration, and the user should be re-prompted after cleaning +// settings.local.json. 
+func TestReviewSettingsMigration_BailsOnLocalSettingsReviewKeys(t *testing.T) { + tmp := t.TempDir() + testutil.InitRepo(t, tmp) + t.Chdir(tmp) + session.ClearGitCommonDirCache() + + entireDir := filepath.Join(tmp, ".entire") + if err := os.MkdirAll(entireDir, 0o750); err != nil { + t.Fatalf("mkdir .entire: %v", err) + } + projectPath := filepath.Join(entireDir, "settings.json") + projectSettings := []byte(`{ + "enabled": true, + "review": {"claude-code": {"prompt": "project"}} + }`) + if err := os.WriteFile(projectPath, projectSettings, 0o600); err != nil { + t.Fatalf("write project settings: %v", err) + } + localPath := filepath.Join(entireDir, "settings.local.json") + localSettings := []byte(`{"review": {"local-agent": {"prompt": "local"}}}`) + if err := os.WriteFile(localPath, localSettings, 0o600); err != nil { + t.Fatalf("write local settings: %v", err) + } + + var out, errOut bytes.Buffer + if err := maybePromptReviewSettingsMigration(context.Background(), &out, &errOut, true, func(context.Context, string, bool) (bool, error) { + t.Fatal("prompt should not be called when settings.local.json has review keys") + return false, nil + }); err != nil { + t.Fatalf("migration: %v", err) + } + + stderr := errOut.String() + for _, want := range []string{"settings.local.json", "review", "Remove"} { + if !strings.Contains(stderr, want) { + t.Errorf("stderr = %q, want it to mention %q", stderr, want) + } + } + + // Project file must NOT have been rewritten — the bail path leaves + // everything in place so the user can clean settings.local.json and + // re-run. + got, err := os.ReadFile(projectPath) + if err != nil { + t.Fatalf("read project settings: %v", err) + } + if !bytes.Contains(got, []byte(`"claude-code"`)) { + t.Fatalf("project file was modified despite bail; got: %s", got) + } + + // Dismissal must NOT be persisted — the user didn't choose to dismiss, + // they hit a fixable precondition. Next run should re-prompt. 
+ prefs, err := settings.LoadClonePreferences(context.Background()) + if err != nil { + t.Fatalf("load preferences: %v", err) + } + if prefs != nil && prefs.ReviewMigrationDismissed { + t.Fatalf("ReviewMigrationDismissed = true after bail; should not persist a fixable precondition as dismissal") + } +} diff --git a/cmd/trace/cli/review/tui_text.go b/cmd/trace/cli/review/tui_text.go index 36b9ba3..0c30557 100644 --- a/cmd/trace/cli/review/tui_text.go +++ b/cmd/trace/cli/review/tui_text.go @@ -28,6 +28,43 @@ func sanitizeDisplayText(s string) string { }, stripped) } +// wrapDisplayWidth (ported from upstream for tui_text_test). +func wrapDisplayWidth(s string, width int) []string { + if width <= 0 { + return nil + } + s = strings.TrimRight(s, "\n") + if s == "" { + return nil + } + paragraphs := strings.Split(s, "\n") + out := make([]string, 0, len(paragraphs)) + for _, p := range paragraphs { + clean := sanitizeDisplayText(p) + if clean == "" { + out = append(out, "") + continue + } + words := strings.Fields(clean) + line := "" + for _, w := range words { + if len(line)+len(w)+1 > width && line != "" { + out = append(out, line) + line = w + } else { + if line != "" { + line += " " + } + line += w + } + } + if line != "" { + out = append(out, line) + } + } + return out +} + func padDisplayWidth(s string, width int) string { return padDisplayWidthWith(s, width, " ") } diff --git a/cmd/trace/cli/review/tui_text_test.go b/cmd/trace/cli/review/tui_text_test.go new file mode 100644 index 0000000..275dccb --- /dev/null +++ b/cmd/trace/cli/review/tui_text_test.go @@ -0,0 +1,86 @@ +package review + +import ( + "reflect" + "testing" +) + +func TestWrapDisplayWidth(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in string + width int + want []string + }{ + { + name: "empty input returns nil", + in: "", + width: 80, + want: nil, + }, + { + name: "zero width returns nil", + in: "anything", + width: 0, + want: nil, + }, + { + name: "negative width returns 
nil", + in: "anything", + width: -10, + want: nil, + }, + { + name: "short single line fits", + in: "hello", + width: 80, + want: []string{"hello"}, + }, + { + name: "long line wraps to width", + in: "aaaa bbbb cccc", + width: 5, + want: []string{"aaaa", "bbbb", "cccc"}, + }, + { + name: "embedded newline preserved as paragraph break", + in: "a\n\nb", + width: 80, + want: []string{"a", "", "b"}, + }, + { + name: "trailing newline does not produce phantom blank line", + in: "text\n", + width: 80, + want: []string{"text"}, + }, + { + name: "multiple trailing newlines collapsed", + in: "text\n\n\n", + width: 80, + want: []string{"text"}, + }, + { + name: "ANSI escape stripped from output", + in: "\x1b[31mred\x1b[0m text", + width: 80, + want: []string{"red text"}, + }, + { + name: "control chars stripped", + in: "a\x07b", + width: 80, + want: []string{"ab"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := wrapDisplayWidth(tt.in, tt.width) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("wrapDisplayWidth(%q, %d) = %#v, want %#v", tt.in, tt.width, got, tt.want) + } + }) + } +} diff --git a/cmd/trace/cli/settings/settings.go b/cmd/trace/cli/settings/settings.go index d80d3a2..1bdfa6d 100644 --- a/cmd/trace/cli/settings/settings.go +++ b/cmd/trace/cli/settings/settings.go @@ -17,6 +17,7 @@ import ( "github.com/GrayCodeAI/trace/cmd/trace/cli/jsonutil" "github.com/GrayCodeAI/trace/cmd/trace/cli/paths" + "github.com/GrayCodeAI/trace/cmd/trace/cli/session" ) const ( @@ -24,6 +25,9 @@ const ( TraceSettingsFile = ".trace/settings.json" // TraceSettingsLocalFile is the path to the local settings override file (not committed) TraceSettingsLocalFile = ".trace/settings.local.json" + // ClonePreferencesFile is the path inside the git common dir for clone-local preferences + // (review migration state, etc.). Adapted from upstream "entire/preferences.json". 
+ ClonePreferencesFile = "trace/preferences.json" // defaultGenerationRetentionDays is the default retention window for archived // checkpoints v2 raw-transcript generations when no override is configured. defaultGenerationRetentionDays = 14 @@ -955,3 +959,131 @@ func saveToFile(ctx context.Context, settings *TraceSettings, filePath string) e } return nil } + +// --- Clone-local preferences and raw settings helpers (ported from upstream for review migration) --- + +// ClonePreferences holds per-clone (non-committed) preferences, primarily for +// the review feature (e.g. which agent to use for review fixes, migration dismissal state). +type ClonePreferences struct { + Review map[string]ReviewConfig `json:"review,omitempty"` + ReviewFixAgent string `json:"review_fix_agent,omitempty"` + + // ReviewMigrationDismissed records that the user declined the one-shot + // migration of review keys from project settings to clone-local prefs. + // Once true, `trace review` stops prompting on every invocation. + ReviewMigrationDismissed bool `json:"review_migration_dismissed,omitempty"` +} + +// LoadProjectRaw reads .trace/settings.json as a generic JSON object. +// Used by review migration to move keys without loading the full typed struct. +func LoadProjectRaw(ctx context.Context) (path string, raw map[string]json.RawMessage, exists bool, err error) { + path, err = paths.AbsPath(ctx, TraceSettingsFile) + if err != nil { + path = TraceSettingsFile + } + data, readErr := os.ReadFile(path) //nolint:gosec + if readErr != nil { + if os.IsNotExist(readErr) { + return path, map[string]json.RawMessage{}, false, nil + } + return path, nil, false, fmt.Errorf("reading project settings: %w", readErr) + } + raw = map[string]json.RawMessage{} + if err := json.Unmarshal(data, &raw); err != nil { + return path, nil, true, fmt.Errorf("parsing project settings: %w", err) + } + return path, raw, true, nil +} + +// LoadLocalRaw reads .trace/settings.local.json as a generic JSON object. 
+func LoadLocalRaw(ctx context.Context) (path string, raw map[string]json.RawMessage, exists bool, err error) { + path, err = paths.AbsPath(ctx, TraceSettingsLocalFile) + if err != nil { + path = TraceSettingsLocalFile + } + data, readErr := os.ReadFile(path) //nolint:gosec + if readErr != nil { + if os.IsNotExist(readErr) { + return path, map[string]json.RawMessage{}, false, nil + } + return path, nil, false, fmt.Errorf("reading local settings: %w", readErr) + } + raw = map[string]json.RawMessage{} + if err := json.Unmarshal(data, &raw); err != nil { + return path, nil, true, fmt.Errorf("parsing local settings: %w", err) + } + return path, raw, true, nil +} + +// SaveProjectRaw writes a generic JSON object back to .trace/settings.json atomically. +func SaveProjectRaw(path string, raw map[string]json.RawMessage) error { + data, err := jsonutil.MarshalIndentWithNewline(raw, "", " ") + if err != nil { + return fmt.Errorf("marshal project settings: %w", err) + } + if err := jsonutil.WriteFileAtomic(path, data, 0o644); err != nil { + return fmt.Errorf("writing project settings: %w", err) + } + return nil +} + +// ClonePreferencesPath returns the path to trace/preferences.json inside the git common dir. +func ClonePreferencesPath(ctx context.Context) (string, error) { + commonDir, err := session.GetGitCommonDir(ctx) + if err != nil { + return "", err + } + return filepath.Join(commonDir, ClonePreferencesFile), nil +} + +// LoadClonePreferences loads clone-local preferences from the git common dir. +func LoadClonePreferences(ctx context.Context) (*ClonePreferences, error) { + path, err := ClonePreferencesPath(ctx) + if err != nil { + return nil, err + } + return loadClonePreferencesFromFile(path) +} + +// SaveClonePreferences saves clone-local preferences to the git common dir. 
+func SaveClonePreferences(ctx context.Context, prefs *ClonePreferences) error { + path, err := ClonePreferencesPath(ctx) + if err != nil { + return err + } + return saveClonePreferencesToFile(prefs, path) +} + +func loadClonePreferencesFromFile(filePath string) (*ClonePreferences, error) { + prefs := &ClonePreferences{} + data, err := os.ReadFile(filePath) //nolint:gosec + if err != nil { + if os.IsNotExist(err) { + return prefs, nil + } + return nil, fmt.Errorf("%w", err) + } + // Lenient decode (unknown fields are ignored) — same rationale as upstream. + if err := json.Unmarshal(data, prefs); err != nil { + return nil, fmt.Errorf("parsing preferences file: %w", err) + } + return prefs, nil +} + +func saveClonePreferencesToFile(prefs *ClonePreferences, filePath string) error { + if prefs == nil { + prefs = &ClonePreferences{} + } + dir := filepath.Dir(filePath) + if err := os.MkdirAll(dir, 0o750); err != nil { + return fmt.Errorf("creating preferences directory: %w", err) + } + data, err := jsonutil.MarshalIndentWithNewline(prefs, "", " ") + if err != nil { + return fmt.Errorf("marshaling preferences: %w", err) + } + if err := jsonutil.WriteFileAtomic(filePath, data, 0o644); err != nil { + return fmt.Errorf("writing preferences file: %w", err) + } + return nil +} diff --git a/cmd/trace/cli/strategy/manual_commit_concurrent_test.go b/cmd/trace/cli/strategy/manual_commit_concurrent_test.go new file mode 100644 index 0000000..7fd8e36 --- /dev/null +++ b/cmd/trace/cli/strategy/manual_commit_concurrent_test.go @@ -0,0 +1,286 @@ +package strategy + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "testing" + + "github.com/GrayCodeAI/trace/cmd/trace/cli/checkpoint" + "github.com/GrayCodeAI/trace/cmd/trace/cli/paths" + "github.com/GrayCodeAI/trace/cmd/trace/cli/testutil" + + "github.com/go-git/go-git/v6" + "github.com/go-git/go-git/v6/plumbing" + "github.com/go-git/go-git/v6/plumbing/filemode" +) + +// 
// TestSaveStep_ConcurrentSessionsSameShadowBranch reproduces the parallel-agents
// scenario behind the Stop-hook error
//
//	failed to write temporary checkpoint: failed to build tree:
//	failed to apply changes in .trace: failed to read tree: object not found
//
// Multiple sessions in the same worktree on the same base commit all hash to the
// same shadow branch name. SaveStep is serialized per-session-ID via
// acquireSessionGate but there is no shadow-branch-wide lock, so the ref update
// at the end of WriteTemporary races.
//
// Each goroutine writes to a unique agent-XX.txt file with unique content per
// step, so every checkpoint produces a distinct tree hash — i.e. the dedup
// short-circuit in WriteTemporary never fires. That invariant is what lets us
// assert per-session StepCount == checkpointsPerWorker below; if a future
// change ever lands two identical checkpoints in a row, the StepCount
// assertion (not just the commit-count check) will catch it.
//
// Assertions:
//   - no SaveStep returns an error
//   - every session's persisted StepCount equals checkpointsPerWorker (no
//     checkpoint was skipped or lost)
//   - the resulting shadow branch is internally consistent: every commit
//     reachable from the ref has a tree where every directory entry resolves
//     (no "object not found" anywhere in the chain)
//   - the shadow branch commit count equals numSessions * checkpointsPerWorker
func TestSaveStep_ConcurrentSessionsSameShadowBranch(t *testing.T) {
	const (
		numSessions          = 8
		checkpointsPerWorker = 4
	)

	dir := t.TempDir()
	testutil.InitRepo(t, dir)
	testutil.WriteFile(t, dir, "seed.txt", "seed\n")
	testutil.GitAdd(t, dir, "seed.txt")
	testutil.GitCommit(t, dir, "initial commit")

	t.Chdir(dir)
	paths.ClearWorktreeRootCache()

	// session bundles the per-worker identifiers: the session ID, its
	// metadata directory (relative and absolute), and the one file it edits.
	type session struct {
		id             string
		metadataDir    string
		metadataDirAbs string
		file           string
	}
	sessions := make([]session, numSessions)
	for i := range sessions {
		id := fmt.Sprintf("2026-05-14-concurrent-%02d", i)
		md := paths.TraceMetadataDir + "/" + id
		sessions[i] = session{
			id:             id,
			metadataDir:    md,
			metadataDirAbs: filepath.Join(dir, md),
			file:           fmt.Sprintf("agent-%02d.txt", i),
		}
		// Seed a transcript file for the session before any goroutine starts.
		testutil.WriteFile(t, dir, md+"/"+paths.TranscriptFileName,
			"{\"type\":\"human\",\"message\":{\"content\":\"start\"}}\n")
	}

	// goroutineErr records which session and step failed; step is -1 for a
	// failure in InitializeSession.
	type goroutineErr struct {
		session string
		step    int
		err     error
	}
	// Buffered beyond the worst case (each goroutine returns after its first
	// error), so a send never blocks a worker.
	errCh := make(chan goroutineErr, numSessions*(checkpointsPerWorker+1))
	start := make(chan struct{})

	var wg sync.WaitGroup
	for i := range sessions {
		sess := sessions[i]
		wg.Go(func() {
			ctx := context.Background()
			// Each goroutine owns its own strategy + repo handle, mirroring the
			// production case where every hook invocation is a fresh process.
			s := NewManualCommitStrategy()

			if err := s.InitializeSession(ctx, sess.id, "Claude Code", "", "", ""); err != nil {
				errCh <- goroutineErr{session: sess.id, step: -1, err: fmt.Errorf("InitializeSession: %w", err)}
				return
			}

			// Wait for all goroutines to be ready, then start together to widen
			// the race window for the SetReference contention.
			<-start

			for step := range checkpointsPerWorker {
				// Unique content per (session, step) keeps every checkpoint's
				// tree distinct.
				content := fmt.Sprintf("session=%s step=%d\n", sess.id, step)
				if err := writeFileForRaceTest(filepath.Join(dir, sess.file), content); err != nil {
					errCh <- goroutineErr{session: sess.id, step: step, err: fmt.Errorf("write worker file: %w", err)}
					return
				}
				transcriptLine := fmt.Sprintf("{\"type\":\"assistant\",\"step\":%d}\n", step)
				transcriptPath := filepath.Join(sess.metadataDirAbs, paths.TranscriptFileName)
				if err := writeFileForRaceTest(transcriptPath, transcriptLine); err != nil {
					errCh <- goroutineErr{session: sess.id, step: step, err: fmt.Errorf("write transcript: %w", err)}
					return
				}

				// The worker's file is reported as new on the first step and
				// as modified on every later step.
				var modified, newFiles []string
				if step == 0 {
					newFiles = []string{sess.file}
				} else {
					modified = []string{sess.file}
				}

				err := s.SaveStep(ctx, StepContext{
					SessionID:      sess.id,
					ModifiedFiles:  modified,
					NewFiles:       newFiles,
					MetadataDir:    sess.metadataDir,
					MetadataDirAbs: sess.metadataDirAbs,
					CommitMessage:  fmt.Sprintf("Checkpoint %d for %s", step, sess.id),
					AuthorName:     "Test",
					AuthorEmail:    "test@example.com",
				})
				if err != nil {
					errCh <- goroutineErr{session: sess.id, step: step, err: fmt.Errorf("SaveStep: %w", err)}
					return
				}
			}
		})
	}

	// Release every goroutine blocked on start, wait for them to finish, then
	// close errCh so the range below terminates.
	close(start)
	wg.Wait()
	close(errCh)

	for ge := range errCh {
		t.Errorf("session %s step %d: %v", ge.session, ge.step, ge.err)
	}
	if t.Failed() {
		return
	}

	// Per-session invariant: every SaveStep call should have landed a
	// checkpoint (no skips from the dedup short-circuit). StepCount is
	// incremented in SaveStep only when WriteTemporary returns Skipped=false,
	// so this catches a future test change that accidentally writes
	// duplicate-content checkpoints — which would surface as a misleading
	// "commits were lost" message in the commit-count check below.
	stateStrategy := NewManualCommitStrategy()
	for _, sess := range sessions {
		state, err := stateStrategy.loadSessionState(context.Background(), sess.id)
		if err != nil {
			t.Errorf("load state for %s: %v", sess.id, err)
			continue
		}
		if state == nil {
			t.Errorf("missing state for %s", sess.id)
			continue
		}
		if state.StepCount != checkpointsPerWorker {
			t.Errorf("session %s StepCount = %d, want %d", sess.id, state.StepCount, checkpointsPerWorker)
		}
	}

	// Verify the shadow branch is internally consistent.
	repo, err := git.PlainOpen(dir)
	if err != nil {
		t.Fatalf("open repo: %v", err)
	}

	shadowBranches := listShadowBranches(t, repo)
	if len(shadowBranches) == 0 {
		t.Fatal("expected at least one shadow branch after SaveStep, found none")
	}
	if len(shadowBranches) > 1 {
		names := make([]string, 0, len(shadowBranches))
		for _, ref := range shadowBranches {
			names = append(names, ref.Name().Short())
		}
		t.Fatalf("expected sessions to share a single shadow branch, got %d: %v", len(shadowBranches), names)
	}

	commits := walkShadowBranchAssertConsistent(t, repo, shadowBranches[0])

	// Commit-count check: every distinct checkpoint we issued should have
	// landed on the shadow branch. See the test-level comment for why dedup
	// can't quietly defeat this assertion.
	expected := numSessions * checkpointsPerWorker
	switch {
	case commits > expected:
		t.Errorf("walked %d commits, more than the %d SaveStep calls — accounting bug", commits, expected)
	case commits < expected:
		t.Errorf("walked %d commits but issued %d SaveStep calls — %d commits were lost",
			commits, expected, expected-commits)
	default:
		t.Logf("walked %d commits matching %d SaveStep calls — no checkpoints lost", commits, expected)
	}
}

// listShadowBranches returns every branch reference in repo whose short name
// starts with checkpoint.ShadowBranchPrefix, excluding the branch named
// paths.MetadataBranchName. It fails the test if references cannot be listed
// or iterated.
func listShadowBranches(t *testing.T, repo *git.Repository) []plumbing.Reference {
	t.Helper()
	refs, err := repo.References()
	if err != nil {
		t.Fatalf("list refs: %v", err)
	}
	var out []plumbing.Reference
	err = refs.ForEach(func(ref *plumbing.Reference) error {
		if ref.Name().IsBranch() && strings.HasPrefix(ref.Name().Short(), checkpoint.ShadowBranchPrefix) &&
			ref.Name().Short() != paths.MetadataBranchName {
			out = append(out, *ref)
		}
		return nil
	})
	if err != nil {
		t.Fatalf("iterate refs: %v", err)
	}
	return out
}

// walkShadowBranchAssertConsistent walks every commit reachable from the shadow
// branch ref and asserts every tree (and recursively every subtree) is in the
// object database. Returns the number of commits visited.
+func walkShadowBranchAssertConsistent(t *testing.T, repo *git.Repository, ref plumbing.Reference) int { + t.Helper() + visited := make(map[plumbing.Hash]bool) + count := 0 + hash := ref.Hash() + for hash != plumbing.ZeroHash { + if visited[hash] { + t.Fatalf("shadow branch %s: cycle at commit %s", ref.Name().Short(), hash) + } + visited[hash] = true + count++ + + commit, err := repo.CommitObject(hash) + if err != nil { + t.Fatalf("shadow branch %s: commit %s unreadable: %v", ref.Name().Short(), hash, err) + } + walkTreeAssertConsistent(t, repo, commit.TreeHash, "/") + + if len(commit.ParentHashes) == 0 { + break + } + hash = commit.ParentHashes[0] + } + return count +} + +func walkTreeAssertConsistent(t *testing.T, repo *git.Repository, hash plumbing.Hash, path string) { + t.Helper() + tree, err := repo.TreeObject(hash) + if err != nil { + t.Fatalf("tree %s at %s unreadable: %v", hash, path, err) + } + for _, entry := range tree.Entries { + if entry.Mode == filemode.Dir { + walkTreeAssertConsistent(t, repo, entry.Hash, path+entry.Name+"/") + } + } +} + +// writeFileForRaceTest is a goroutine-safe alternative to testutil.WriteFile, +// scoped to this test file. testutil.WriteFile calls t.Fatalf, which doesn't +// fail the test cleanly from a sub-goroutine. Name kept long and specific so +// it can't accidentally shadow a more general helper added to the package +// later. 
+func writeFileForRaceTest(absPath, content string) error { + if err := os.MkdirAll(filepath.Dir(absPath), 0o755); err != nil { + return fmt.Errorf("mkdir: %w", err) + } + return os.WriteFile(absPath, []byte(content), 0o644) +} diff --git a/cmd/trace/cli/strategy/manual_commit_git.go b/cmd/trace/cli/strategy/manual_commit_git.go index 4e0e7e8..30a7619 100644 --- a/cmd/trace/cli/strategy/manual_commit_git.go +++ b/cmd/trace/cli/strategy/manual_commit_git.go @@ -9,6 +9,7 @@ import ( "sort" "github.com/GrayCodeAI/trace/cmd/trace/cli/agent" + "github.com/GrayCodeAI/trace/cmd/trace/cli/agent/types" "github.com/GrayCodeAI/trace/cmd/trace/cli/checkpoint" "github.com/GrayCodeAI/trace/cmd/trace/cli/checkpoint/id" "github.com/GrayCodeAI/trace/cmd/trace/cli/logging" @@ -31,159 +32,139 @@ func (s *ManualCommitStrategy) SaveStep(ctx context.Context, step StepContext) e } openRepoSpan.End() - // Extract session ID from metadata dir sessionID := filepath.Base(step.MetadataDir) - // Load or initialize session state - _, loadStateSpan := perf.Start(ctx, "load_session_state") - state, err := s.loadSessionState(ctx, sessionID) - if err != nil { - loadStateSpan.RecordError(err) - loadStateSpan.End() - return fmt.Errorf("failed to load session state: %w", err) + // Initialize the session if no state exists yet. Done outside + // MutateSessionState because the helper bails with ErrStateNotFound on + // missing state — initialization establishes the file the helper will + // then mutate under lock. 
+ if err := s.ensureSessionInitialized(ctx, repo, sessionID, step.AgentType); err != nil { + return err } - // Initialize if state is nil OR BaseCommit is empty (can happen with partial state from warnings) - if state == nil || state.BaseCommit == "" { - agentType := resolveAgentType(step.AgentType, state) - state, err = s.initializeSession(ctx, repo, sessionID, agentType, "", "", "") // No transcript/prompt/model in fallback + + mutErr := MutateSessionState(ctx, sessionID, func(state *SessionState) error { + _, migrateSpan := perf.Start(ctx, "migrate_shadow_branch") + if _, _, err := s.migrateShadowBranchIfNeeded(ctx, repo, state); err != nil { + migrateSpan.RecordError(err) + migrateSpan.End() + return fmt.Errorf("failed to check/migrate shadow branch: %w", err) + } + migrateSpan.End() + + store, err := s.getCheckpointStore() if err != nil { - loadStateSpan.RecordError(err) - loadStateSpan.End() - return fmt.Errorf("failed to initialize session: %w", err) + return fmt.Errorf("failed to get checkpoint store: %w", err) } - } - loadStateSpan.End() - // Check if HEAD has changed (e.g., Claude did a rebase via tool call) and migrate if needed - _, migrateSpan := perf.Start(ctx, "migrate_shadow_branch") - if err := s.migrateAndPersistIfNeeded(ctx, repo, state); err != nil { - migrateSpan.RecordError(err) - migrateSpan.End() - return err - } - migrateSpan.End() + shadowBranchName := checkpoint.ShadowBranchNameForCommit(state.BaseCommit, state.WorktreeID) + branchExisted := store.ShadowBranchExists(state.BaseCommit, state.WorktreeID) - // Get checkpoint store - store, err := s.getCheckpointStore() - if err != nil { - return fmt.Errorf("failed to get checkpoint store: %w", err) - } + var promptAttr PromptAttribution + if state.PendingPromptAttribution != nil { + promptAttr = *state.PendingPromptAttribution + state.PendingPromptAttribution = nil + } else { + promptAttr = PromptAttribution{CheckpointNumber: state.StepCount + 1} + } - // Check if shadow branch exists to 
report whether we created it - shadowBranchName := checkpoint.ShadowBranchNameForCommit(state.BaseCommit, state.WorktreeID) - branchExisted := store.ShadowBranchExists(state.BaseCommit, state.WorktreeID) - - // Use the pending attribution calculated at prompt start (in InitializeSession) - // This was calculated BEFORE the agent made changes, so it accurately captures user edits - var promptAttr PromptAttribution - if state.PendingPromptAttribution != nil { - promptAttr = *state.PendingPromptAttribution - state.PendingPromptAttribution = nil // Clear after use - } else { - // No pending attribution (e.g., first checkpoint or session initialized without it) - promptAttr = PromptAttribution{CheckpointNumber: state.StepCount + 1} - } + attrLogCtx := logging.WithComponent(ctx, "attribution") + logging.Debug(attrLogCtx, "prompt attribution at checkpoint save", + slog.Int("checkpoint_number", promptAttr.CheckpointNumber), + slog.Int("user_added", promptAttr.UserLinesAdded), + slog.Int("user_removed", promptAttr.UserLinesRemoved), + slog.Int("agent_added", promptAttr.AgentLinesAdded), + slog.Int("agent_removed", promptAttr.AgentLinesRemoved), + slog.String("session_id", sessionID)) + + _, writeCheckpointSpan := perf.Start(ctx, "write_temporary_checkpoint") + isFirstCheckpointOfSession := state.StepCount == 0 + result, err := store.WriteTemporary(ctx, checkpoint.WriteTemporaryOptions{ + SessionID: sessionID, + BaseCommit: state.BaseCommit, + WorktreeID: state.WorktreeID, + ModifiedFiles: step.ModifiedFiles, + NewFiles: step.NewFiles, + DeletedFiles: step.DeletedFiles, + MetadataDir: step.MetadataDir, + MetadataDirAbs: step.MetadataDirAbs, + CommitMessage: step.CommitMessage, + AuthorName: step.AuthorName, + AuthorEmail: step.AuthorEmail, + IsFirstCheckpoint: isFirstCheckpointOfSession, + }) + writeCheckpointSpan.RecordError(err) + writeCheckpointSpan.End() + if err != nil { + return fmt.Errorf("failed to write temporary checkpoint: %w", err) + } - // Log the prompt 
attribution for debugging - attrLogCtx := logging.WithComponent(ctx, "attribution") - logging.Debug(attrLogCtx, "prompt attribution at checkpoint save", - slog.Int("checkpoint_number", promptAttr.CheckpointNumber), - slog.Int("user_added", promptAttr.UserLinesAdded), - slog.Int("user_removed", promptAttr.UserLinesRemoved), - slog.Int("agent_added", promptAttr.AgentLinesAdded), - slog.Int("agent_removed", promptAttr.AgentLinesRemoved), - slog.String("session_id", sessionID)) - - // Use WriteTemporary to create the checkpoint - _, writeCheckpointSpan := perf.Start(ctx, "write_temporary_checkpoint") - isFirstCheckpointOfSession := state.StepCount == 0 - result, err := store.WriteTemporary(ctx, checkpoint.WriteTemporaryOptions{ - SessionID: sessionID, - BaseCommit: state.BaseCommit, - WorktreeID: state.WorktreeID, - ModifiedFiles: step.ModifiedFiles, - NewFiles: step.NewFiles, - DeletedFiles: step.DeletedFiles, - MetadataDir: step.MetadataDir, - MetadataDirAbs: step.MetadataDirAbs, - CommitMessage: step.CommitMessage, - AuthorName: step.AuthorName, - AuthorEmail: step.AuthorEmail, - IsFirstCheckpoint: isFirstCheckpointOfSession, - }) - writeCheckpointSpan.RecordError(err) - writeCheckpointSpan.End() - if err != nil { - return fmt.Errorf("failed to write temporary checkpoint: %w", err) - } + if result.Skipped { + logCtx := logging.WithComponent(ctx, "checkpoint") + logging.Info(logCtx, "checkpoint skipped (no changes)", + slog.String("strategy", "manual-commit"), + slog.String("checkpoint_type", "session"), + slog.Int("checkpoint_count", state.StepCount), + slog.String("shadow_branch", shadowBranchName), + ) + return ErrMutationSkip + } + + // LastCheckpointID is intentionally NOT cleared here. It is set during + // condensation and used by handleAmendCommitMsg to restore checkpoint + // trailers on amend operations. 
+ state.StepCount++ + state.PromptAttributions = append(state.PromptAttributions, promptAttr) + state.FilesTouched = mergeFilesTouched(state.FilesTouched, step.ModifiedFiles, step.NewFiles, step.DeletedFiles) + if state.StepCount == 1 { + state.TranscriptIdentifierAtStart = step.StepTranscriptIdentifier + } + if step.TokenUsage != nil { + state.TokenUsage = accumulateTokenUsage(state.TokenUsage, step.TokenUsage) + } + + if !branchExisted { + logging.Info(logging.WithComponent(ctx, "checkpoint"), "created shadow branch and committed changes", + slog.String("shadow_branch", shadowBranchName)) + } else { + logging.Info(logging.WithComponent(ctx, "checkpoint"), "committed changes to shadow branch", + slog.String("shadow_branch", shadowBranchName)) + } - // If checkpoint was skipped due to deduplication (no changes), return early - if result.Skipped { logCtx := logging.WithComponent(ctx, "checkpoint") - logging.Info( - logCtx, "checkpoint skipped (no changes)", + logging.Info(logCtx, "checkpoint saved", slog.String("strategy", "manual-commit"), slog.String("checkpoint_type", "session"), slog.Int("checkpoint_count", state.StepCount), + slog.Int("modified_files", len(step.ModifiedFiles)), + slog.Int("new_files", len(step.NewFiles)), + slog.Int("deleted_files", len(step.DeletedFiles)), slog.String("shadow_branch", shadowBranchName), + slog.Bool("branch_created", !branchExisted), ) return nil + }) + if mutErr != nil && !errors.Is(mutErr, ErrMutationSkip) { + return mutErr } + return nil +} - // Update session state - _, updateStateSpan := perf.Start(ctx, "update_session_state") - state.StepCount++ - - // Note: LastCheckpointID is intentionally NOT cleared here. - // It is set during condensation and used by handleAmendCommitMsg - // to restore checkpoint trailers on amend operations. 
- - // Store the prompt attribution we calculated before saving - state.PromptAttributions = append(state.PromptAttributions, promptAttr) - - // Track touched files (modified, new, and deleted) - state.FilesTouched = mergeFilesTouched(state.FilesTouched, step.ModifiedFiles, step.NewFiles, step.DeletedFiles) - - // On first checkpoint, record the transcript identifier for this session - if state.StepCount == 1 { - state.TranscriptIdentifierAtStart = step.StepTranscriptIdentifier - } - - // Accumulate token usage - if step.TokenUsage != nil { - state.TokenUsage = accumulateTokenUsage(state.TokenUsage, step.TokenUsage) +// ensureSessionInitialized makes sure a session state file exists for the +// given session ID. Called outside MutateSessionState because the helper bails +// with ErrStateNotFound on missing state — initialization establishes the file +// the helper will then mutate under lock. +func (s *ManualCommitStrategy) ensureSessionInitialized(ctx context.Context, repo *git.Repository, sessionID string, agentTypeHint types.AgentType) error { + state, err := s.loadSessionState(ctx, sessionID) + if err != nil { + return fmt.Errorf("failed to load session state: %w", err) } - - // Save updated state - if err := s.saveSessionState(ctx, state); err != nil { - updateStateSpan.RecordError(err) - updateStateSpan.End() - return fmt.Errorf("failed to save session state: %w", err) + if state != nil && state.BaseCommit != "" { + return nil // already initialized } - updateStateSpan.End() - - if !branchExisted { - logging.Info(logging.WithComponent(ctx, "checkpoint"), "created shadow branch and committed changes", - slog.String("shadow_branch", shadowBranchName)) - } else { - logging.Info(logging.WithComponent(ctx, "checkpoint"), "committed changes to shadow branch", - slog.String("shadow_branch", shadowBranchName)) + agentType := resolveAgentType(agentTypeHint, state) + if _, err := s.initializeSession(ctx, repo, sessionID, agentType, "", "", ""); err != nil { + return 
fmt.Errorf("failed to initialize session: %w", err) } - - // Log checkpoint creation - logCtx := logging.WithComponent(ctx, "checkpoint") - logging.Info( - logCtx, "checkpoint saved", - slog.String("strategy", "manual-commit"), - slog.String("checkpoint_type", "session"), - slog.Int("checkpoint_count", state.StepCount), - slog.Int("modified_files", len(step.ModifiedFiles)), - slog.Int("new_files", len(step.NewFiles)), - slog.Int("deleted_files", len(step.DeletedFiles)), - slog.String("shadow_branch", shadowBranchName), - slog.Bool("branch_created", !branchExisted), - ) - return nil } @@ -195,125 +176,111 @@ func (s *ManualCommitStrategy) SaveTaskStep(ctx context.Context, step TaskStepCo return fmt.Errorf("failed to open git repository: %w", err) } - // Load session state - state, err := s.loadSessionState(ctx, step.SessionID) - if err != nil || state == nil || state.BaseCommit == "" { - agentType := resolveAgentType(step.AgentType, state) - state, err = s.initializeSession(ctx, repo, step.SessionID, agentType, "", "", "") // No transcript/prompt/model in fallback - if err != nil { - return fmt.Errorf("failed to initialize session for task checkpoint: %w", err) - } - } - - // Check if HEAD has changed (e.g., Claude did a rebase via tool call) and migrate if needed - if err := s.migrateAndPersistIfNeeded(ctx, repo, state); err != nil { + if err := s.ensureSessionInitialized(ctx, repo, step.SessionID, step.AgentType); err != nil { return err } - // Get checkpoint store - store, err := s.getCheckpointStore() - if err != nil { - return fmt.Errorf("failed to get checkpoint store: %w", err) - } + mutErr := MutateSessionState(ctx, step.SessionID, func(state *SessionState) error { + if _, _, migrateErr := s.migrateShadowBranchIfNeeded(ctx, repo, state); migrateErr != nil { + return fmt.Errorf("failed to check/migrate shadow branch: %w", migrateErr) + } - // Check if shadow branch exists to report whether we created it - shadowBranchName := 
checkpoint.ShadowBranchNameForCommit(state.BaseCommit, state.WorktreeID) - branchExisted := store.ShadowBranchExists(state.BaseCommit, state.WorktreeID) + store, storeErr := s.getCheckpointStore() + if storeErr != nil { + return fmt.Errorf("failed to get checkpoint store: %w", storeErr) + } - // Compute metadata paths for commit message - sessionMetadataDir := paths.SessionMetadataDirFromSessionID(step.SessionID) - taskMetadataDir := TaskMetadataDir(sessionMetadataDir, step.ToolUseID) + shadowBranchName := checkpoint.ShadowBranchNameForCommit(state.BaseCommit, state.WorktreeID) + branchExisted := store.ShadowBranchExists(state.BaseCommit, state.WorktreeID) - // Generate commit message - shortToolUseID := step.ToolUseID - if len(shortToolUseID) > id.ShortIDLength { - shortToolUseID = shortToolUseID[:id.ShortIDLength] - } + sessionMetadataDir := paths.SessionMetadataDirFromSessionID(step.SessionID) + taskMetadataDir := TaskMetadataDir(sessionMetadataDir, step.ToolUseID) - var messageSubject string - if step.IsIncremental { - messageSubject = FormatIncrementalSubject( - step.IncrementalType, - step.SubagentType, - step.TaskDescription, - step.TodoContent, - step.IncrementalSequence, - shortToolUseID, + shortToolUseID := step.ToolUseID + if len(shortToolUseID) > id.ShortIDLength { + shortToolUseID = shortToolUseID[:id.ShortIDLength] + } + + var messageSubject string + if step.IsIncremental { + messageSubject = FormatIncrementalSubject( + step.IncrementalType, + step.SubagentType, + step.TaskDescription, + step.TodoContent, + step.IncrementalSequence, + shortToolUseID, + ) + } else { + messageSubject = FormatSubagentEndMessage(step.SubagentType, step.TaskDescription, shortToolUseID) + } + commitMsg := trailers.FormatShadowTaskCommit( + messageSubject, + taskMetadataDir, + step.SessionID, ) - } else { - messageSubject = FormatSubagentEndMessage(step.SubagentType, step.TaskDescription, shortToolUseID) - } - commitMsg := trailers.FormatShadowTaskCommit( - messageSubject, - 
taskMetadataDir, - step.SessionID, - ) - - // Use WriteTemporaryTask to create the checkpoint - _, err = store.WriteTemporaryTask(ctx, checkpoint.WriteTemporaryTaskOptions{ - SessionID: step.SessionID, - BaseCommit: state.BaseCommit, - WorktreeID: state.WorktreeID, - ToolUseID: step.ToolUseID, - AgentID: step.AgentID, - ModifiedFiles: step.ModifiedFiles, - NewFiles: step.NewFiles, - DeletedFiles: step.DeletedFiles, - TranscriptPath: step.TranscriptPath, - SubagentTranscriptPath: step.SubagentTranscriptPath, - CheckpointUUID: step.CheckpointUUID, - CommitMessage: commitMsg, - AuthorName: step.AuthorName, - AuthorEmail: step.AuthorEmail, - IsIncremental: step.IsIncremental, - IncrementalSequence: step.IncrementalSequence, - IncrementalType: step.IncrementalType, - IncrementalData: step.IncrementalData, - }) - if err != nil { - return fmt.Errorf("failed to write task checkpoint: %w", err) - } - // Track touched files (modified, new, and deleted) - state.FilesTouched = mergeFilesTouched(state.FilesTouched, step.ModifiedFiles, step.NewFiles, step.DeletedFiles) + _, writeErr := store.WriteTemporaryTask(ctx, checkpoint.WriteTemporaryTaskOptions{ + SessionID: step.SessionID, + BaseCommit: state.BaseCommit, + WorktreeID: state.WorktreeID, + ToolUseID: step.ToolUseID, + AgentID: step.AgentID, + ModifiedFiles: step.ModifiedFiles, + NewFiles: step.NewFiles, + DeletedFiles: step.DeletedFiles, + TranscriptPath: step.TranscriptPath, + SubagentTranscriptPath: step.SubagentTranscriptPath, + CheckpointUUID: step.CheckpointUUID, + CommitMessage: commitMsg, + AuthorName: step.AuthorName, + AuthorEmail: step.AuthorEmail, + IsIncremental: step.IsIncremental, + IncrementalSequence: step.IncrementalSequence, + IncrementalType: step.IncrementalType, + IncrementalData: step.IncrementalData, + }) + if writeErr != nil { + return fmt.Errorf("failed to write task checkpoint: %w", writeErr) + } - // Save updated state - if err := s.saveSessionState(ctx, state); err != nil { - return 
fmt.Errorf("failed to save session state: %w", err) - } + state.FilesTouched = mergeFilesTouched(state.FilesTouched, step.ModifiedFiles, step.NewFiles, step.DeletedFiles) - if !branchExisted { - logging.Info(logging.WithComponent(ctx, "checkpoint"), "created shadow branch and committed task checkpoint", - slog.String("shadow_branch", shadowBranchName)) - } else { - logging.Info(logging.WithComponent(ctx, "checkpoint"), "committed task checkpoint to shadow branch", - slog.String("shadow_branch", shadowBranchName)) - } + if !branchExisted { + logging.Info(logging.WithComponent(ctx, "checkpoint"), "created shadow branch and committed task checkpoint", + slog.String("shadow_branch", shadowBranchName)) + } else { + logging.Info(logging.WithComponent(ctx, "checkpoint"), "committed task checkpoint to shadow branch", + slog.String("shadow_branch", shadowBranchName)) + } - // Log task checkpoint creation - logCtx := logging.WithComponent(ctx, "checkpoint") - attrs := []any{ - slog.String("strategy", "manual-commit"), - slog.String("checkpoint_type", "task"), - slog.String("checkpoint_uuid", step.CheckpointUUID), - slog.String("tool_use_id", step.ToolUseID), - slog.String("subagent_type", step.SubagentType), - slog.Int("modified_files", len(step.ModifiedFiles)), - slog.Int("new_files", len(step.NewFiles)), - slog.Int("deleted_files", len(step.DeletedFiles)), - slog.String("shadow_branch", shadowBranchName), - slog.Bool("branch_created", !branchExisted), - } - if step.IsIncremental { - attrs = append( - attrs, - slog.Bool("is_incremental", true), - slog.String("incremental_type", step.IncrementalType), - slog.Int("incremental_sequence", step.IncrementalSequence), - ) - } - logging.Info(logCtx, "task checkpoint saved", attrs...) 
+ logCtx := logging.WithComponent(ctx, "checkpoint") + attrs := []any{ + slog.String("strategy", "manual-commit"), + slog.String("checkpoint_type", "task"), + slog.String("checkpoint_uuid", step.CheckpointUUID), + slog.String("tool_use_id", step.ToolUseID), + slog.String("subagent_type", step.SubagentType), + slog.Int("modified_files", len(step.ModifiedFiles)), + slog.Int("new_files", len(step.NewFiles)), + slog.Int("deleted_files", len(step.DeletedFiles)), + slog.String("shadow_branch", shadowBranchName), + slog.Bool("branch_created", !branchExisted), + } + if step.IsIncremental { + attrs = append(attrs, + slog.Bool("is_incremental", true), + slog.String("incremental_type", step.IncrementalType), + slog.Int("incremental_sequence", step.IncrementalSequence), + ) + } + logging.Info(logCtx, "task checkpoint saved", attrs...) + return nil + }) + if mutErr != nil { + return mutErr + } return nil } diff --git a/cmd/trace/cli/strategy/session_state.go b/cmd/trace/cli/strategy/session_state.go index 5b6f7cf..8415194 100644 --- a/cmd/trace/cli/strategy/session_state.go +++ b/cmd/trace/cli/strategy/session_state.go @@ -7,10 +7,15 @@ import ( "log/slog" "os" "path/filepath" + "runtime" + "slices" + "strconv" "strings" + "sync" "github.com/GrayCodeAI/trace/cmd/trace/cli/agent" "github.com/GrayCodeAI/trace/cmd/trace/cli/agent/types" + "github.com/GrayCodeAI/trace/cmd/trace/cli/internal/flock" "github.com/GrayCodeAI/trace/cmd/trace/cli/jsonutil" "github.com/GrayCodeAI/trace/cmd/trace/cli/logging" "github.com/GrayCodeAI/trace/cmd/trace/cli/paths" @@ -381,3 +386,200 @@ func LoadAgentTypeHint(ctx context.Context, sessionID string) types.AgentType { } return types.AgentType(strings.TrimSpace(string(data))) } + +// sessionMutationGate provides per-process serialization layered over the +// OS-level flock so that nested MutateSessionState calls in the same +// goroutine don't deadlock or lose updates. 
POSIX flock isn't reentrant +// across distinct file descriptors in the same process; on top of that, a +// nested call that did its own load → save would have its save overwritten +// by the outer save. The gate fixes both: nested calls in the same +// goroutine reuse the outer's state pointer (no second load, no second +// save), and only the outermost release drops the flock. +var sessionMutationGate sync.Map // map[string]*sessionGate + +type sessionGate struct { + mu sync.Mutex + owner int64 // goroutine ID of the current holder, 0 when unlocked + depth int + flockRel func() + activeState *SessionState // shared state pointer for nested mutations +} + +// goroutineID extracts the runtime goroutine ID from the stack header. Used +// only as a reentrancy key for the session mutation gate — never as a +// security boundary or for application logic. +func goroutineID() int64 { + var buf [64]byte + n := runtime.Stack(buf[:], false) + const prefix = "goroutine " + s := string(buf[:n]) + if !strings.HasPrefix(s, prefix) { + return -1 + } + s = s[len(prefix):] + end := strings.IndexByte(s, ' ') + if end < 0 { + return -1 + } + id, err := strconv.ParseInt(s[:end], 10, 64) + if err != nil { + return -1 + } + return id +} + +// ErrMutationSkip signals MutateSessionState to skip the save without +// treating fn's return as an error. Use it when the mutation function +// observes the loaded state and decides no write is needed. +var ErrMutationSkip = errors.New("session state mutation skipped") + +// ErrStateNotFound is returned by MutateSessionState when no state file +// exists for the session ID (typically because the event arrived before +// InitializeSession ran). +var ErrStateNotFound = errors.New("session state not found") + +// MutateSessionState is the safe load → mutate → save helper. It takes an +// OS-level advisory lock against .git/trace-session-locks/.lock for the +// duration of the read+write so concurrent processes cannot lose each +// other's updates. 
fn receives the freshly-loaded state and mutates it in +// place; returning ErrMutationSkip skips the save. Reentrant within the same +// goroutine: nested calls share the outer's state pointer and skip the +// inner load/save, so all mutations are flushed by the outermost call. +// +// Returns ErrStateNotFound if the state file doesn't exist (event arrived +// before InitializeSession). Errors from fn or from load/save propagate. +func MutateSessionState(ctx context.Context, sessionID string, fn func(*SessionState) error) error { + if sessionID == "" { + return ErrStateNotFound + } + gate, isOuter, release, err := acquireSessionGate(ctx, sessionID) + if err != nil { + return err + } + defer release() + + if !isOuter { + // Nested call: reuse the outer's state pointer. + if gate.activeState == nil { + return ErrStateNotFound + } + if err := fn(gate.activeState); err != nil && !errors.Is(err, ErrMutationSkip) { + return err + } + return nil + } + + state, err := LoadSessionState(ctx, sessionID) + if err != nil { + return fmt.Errorf("load session state: %w", err) + } + if state == nil { + return ErrStateNotFound + } + gate.activeState = state + defer func() { gate.activeState = nil }() + + if err := fn(state); err != nil { + if errors.Is(err, ErrMutationSkip) { + return nil + } + return err + } + if err := SaveSessionState(ctx, state); err != nil { + return fmt.Errorf("save session state: %w", err) + } + return nil +} + +// acquireSessionGate takes the per-process gate (in-memory) and, on the +// outermost call, the cross-process flock. 
+func acquireSessionGate(ctx context.Context, sessionID string) (gate *sessionGate, isOuter bool, release func(), err error) { + val, _ := sessionMutationGate.LoadOrStore(sessionID, &sessionGate{}) + gate, ok := val.(*sessionGate) + if !ok { + return nil, false, nil, fmt.Errorf("session gate type assertion failed for %s", sessionID) + } + + gid := goroutineID() + gate.mu.Lock() + if gate.owner == gid { + gate.depth++ + gate.mu.Unlock() + return gate, false, func() { + gate.mu.Lock() + gate.depth-- + gate.mu.Unlock() + }, nil + } + gate.mu.Unlock() + + lockPath, err := stateLockPath(ctx, sessionID) + if err != nil { + return nil, false, nil, fmt.Errorf("resolve state lock path: %w", err) + } + flockRel, err := flock.Acquire(lockPath) + if err != nil { + return nil, false, nil, fmt.Errorf("acquire state lock: %w", err) + } + + gate.mu.Lock() + gate.owner = gid + gate.depth = 1 + gate.flockRel = flockRel + gate.mu.Unlock() + + return gate, true, func() { + gate.mu.Lock() + gate.depth-- + if gate.depth == 0 { + rel := gate.flockRel + gate.flockRel = nil + gate.owner = 0 + gate.mu.Unlock() + rel() + return + } + gate.mu.Unlock() + }, nil +} + +// stateLockPath returns the lock file path for a session. Lock files live in +// .git/trace-session-locks/ (a sibling to trace-sessions/) so callers that +// enumerate session state files don't have to filter lock entries. 
+func stateLockPath(ctx context.Context, sessionID string) (string, error) { + if err := validation.ValidateSessionID(sessionID); err != nil { + return "", fmt.Errorf("invalid session ID: %w", err) + } + commonDir, err := GetGitCommonDir(ctx) + if err != nil { + return "", err + } + lockDir := filepath.Join(commonDir, "trace-session-locks") + if err := os.MkdirAll(lockDir, 0o750); err != nil { + return "", fmt.Errorf("create session lock directory: %w", err) + } + return filepath.Join(lockDir, sessionID+".lock"), nil +} + +// RecordFilesTouched merges paths into the session's FilesTouched, used by +// mid-turn lifecycle events (per-tool-use hooks) so PostCommit's carry-forward +// decision sees an accurate file list. Caller must pre-normalize paths to +// repo-relative form. No-ops when the session state doesn't exist or the +// merge produced no changes. +func RecordFilesTouched(ctx context.Context, sessionID string, modified, added, deleted []string) error { + if len(modified) == 0 && len(added) == 0 && len(deleted) == 0 { + return nil + } + err := MutateSessionState(ctx, sessionID, func(state *SessionState) error { + merged := mergeFilesTouched(state.FilesTouched, modified, added, deleted) + if slices.Equal(merged, state.FilesTouched) { + return ErrMutationSkip + } + state.FilesTouched = merged + return nil + }) + if errors.Is(err, ErrStateNotFound) { + return nil + } + return err +} diff --git a/cmd/trace/cli/strategy/session_state_test.go b/cmd/trace/cli/strategy/session_state_test.go index 171c9fd..092b23e 100644 --- a/cmd/trace/cli/strategy/session_state_test.go +++ b/cmd/trace/cli/strategy/session_state_test.go @@ -581,6 +581,344 @@ func TestLoadModelHint_TrimsWhitespace(t *testing.T) { } } +// --- MutateSessionState tests --- + +func TestMutateSessionState_BasicMutation(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + ctx := 
context.Background() + state := &SessionState{ + SessionID: "mutate-basic", + BaseCommit: "abc123", + StartedAt: time.Now(), + StepCount: 1, + } + if err := SaveSessionState(ctx, state); err != nil { + t.Fatalf("SaveSessionState() error = %v", err) + } + + err = MutateSessionState(ctx, "mutate-basic", func(s *SessionState) error { + s.StepCount = 5 + return nil + }) + if err != nil { + t.Fatalf("MutateSessionState() error = %v", err) + } + + loaded, err := LoadSessionState(ctx, "mutate-basic") + if err != nil { + t.Fatalf("LoadSessionState() error = %v", err) + } + if loaded.StepCount != 5 { + t.Errorf("StepCount = %d, want 5", loaded.StepCount) + } +} + +func TestMutateSessionState_SkipSave(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + ctx := context.Background() + state := &SessionState{ + SessionID: "mutate-skip", + BaseCommit: "abc123", + StartedAt: time.Now(), + StepCount: 3, + } + if err := SaveSessionState(ctx, state); err != nil { + t.Fatalf("SaveSessionState() error = %v", err) + } + + err = MutateSessionState(ctx, "mutate-skip", func(s *SessionState) error { + s.StepCount = 999 + return ErrMutationSkip + }) + if err != nil { + t.Fatalf("MutateSessionState() with ErrMutationSkip error = %v", err) + } + + loaded, err := LoadSessionState(ctx, "mutate-skip") + if err != nil { + t.Fatalf("LoadSessionState() error = %v", err) + } + if loaded.StepCount != 3 { + t.Errorf("StepCount = %d, want 3 (skip should not save)", loaded.StepCount) + } +} + +func TestMutateSessionState_NotFound(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + err = MutateSessionState(context.Background(), "nonexistent", func(s *SessionState) error { + return nil + }) + if !errors.Is(err, ErrStateNotFound) { + t.Errorf("MutateSessionState() error = %v, want 
ErrStateNotFound", err) + } +} + +func TestMutateSessionState_EmptySessionID(t *testing.T) { + t.Parallel() + err := MutateSessionState(context.Background(), "", func(s *SessionState) error { + return nil + }) + if !errors.Is(err, ErrStateNotFound) { + t.Errorf("MutateSessionState('') error = %v, want ErrStateNotFound", err) + } +} + +func TestMutateSessionState_NestedCallsShareState(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + ctx := context.Background() + state := &SessionState{ + SessionID: "mutate-nested", + BaseCommit: "abc123", + StartedAt: time.Now(), + StepCount: 1, + } + if err := SaveSessionState(ctx, state); err != nil { + t.Fatalf("SaveSessionState() error = %v", err) + } + + err = MutateSessionState(ctx, "mutate-nested", func(outer *SessionState) error { + outer.StepCount = 10 + + // Nested call should see the outer's mutation and not deadlock. + return MutateSessionState(ctx, "mutate-nested", func(inner *SessionState) error { + if inner.StepCount != 10 { + t.Errorf("nested StepCount = %d, want 10 (should see outer mutation)", inner.StepCount) + } + inner.CheckpointTranscriptStart = 42 + return nil + }) + }) + if err != nil { + t.Fatalf("MutateSessionState() nested error = %v", err) + } + + loaded, err := LoadSessionState(ctx, "mutate-nested") + if err != nil { + t.Fatalf("LoadSessionState() error = %v", err) + } + if loaded.StepCount != 10 { + t.Errorf("StepCount = %d, want 10", loaded.StepCount) + } + if loaded.CheckpointTranscriptStart != 42 { + t.Errorf("CheckpointTranscriptStart = %d, want 42", loaded.CheckpointTranscriptStart) + } +} + +func TestMutateSessionState_FnError(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + ctx := context.Background() + state := &SessionState{ + SessionID: "mutate-fn-err", + BaseCommit: 
"abc123", + StartedAt: time.Now(), + StepCount: 1, + } + if err := SaveSessionState(ctx, state); err != nil { + t.Fatalf("SaveSessionState() error = %v", err) + } + + sentinel := errors.New("mutation failed") + err = MutateSessionState(ctx, "mutate-fn-err", func(s *SessionState) error { + s.StepCount = 999 + return sentinel + }) + if !errors.Is(err, sentinel) { + t.Errorf("MutateSessionState() error = %v, want %v", err, sentinel) + } + + // State should NOT have been saved (fn returned error before save). + loaded, err := LoadSessionState(ctx, "mutate-fn-err") + if err != nil { + t.Fatalf("LoadSessionState() error = %v", err) + } + if loaded.StepCount != 1 { + t.Errorf("StepCount = %d, want 1 (error should prevent save)", loaded.StepCount) + } +} + +// --- RecordFilesTouched tests --- + +func TestRecordFilesTouched_MergesIntoState(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + ctx := context.Background() + state := &SessionState{ + SessionID: "rft-merge", + BaseCommit: "abc123", + StartedAt: time.Now(), + FilesTouched: []string{"existing.go"}, + } + if err := SaveSessionState(ctx, state); err != nil { + t.Fatalf("SaveSessionState() error = %v", err) + } + + err = RecordFilesTouched(ctx, "rft-merge", + []string{"modified.go"}, + []string{"added.go"}, + []string{"deleted.go"}, + ) + if err != nil { + t.Fatalf("RecordFilesTouched() error = %v", err) + } + + loaded, err := LoadSessionState(ctx, "rft-merge") + if err != nil { + t.Fatalf("LoadSessionState() error = %v", err) + } + + expected := map[string]bool{ + "existing.go": true, + "modified.go": true, + "added.go": true, + "deleted.go": true, + } + got := make(map[string]bool) + for _, f := range loaded.FilesTouched { + got[f] = true + } + for f := range expected { + if !got[f] { + t.Errorf("FilesTouched missing %q, got %v", f, loaded.FilesTouched) + } + } +} + +func 
TestRecordFilesTouched_EmptyInputs_NoOp(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + ctx := context.Background() + state := &SessionState{ + SessionID: "rft-empty", + BaseCommit: "abc123", + StartedAt: time.Now(), + FilesTouched: []string{"file.go"}, + } + if err := SaveSessionState(ctx, state); err != nil { + t.Fatalf("SaveSessionState() error = %v", err) + } + + err = RecordFilesTouched(ctx, "rft-empty", nil, nil, nil) + if err != nil { + t.Fatalf("RecordFilesTouched() with empty inputs error = %v", err) + } + + // State should be unchanged. + loaded, err := LoadSessionState(ctx, "rft-empty") + if err != nil { + t.Fatalf("LoadSessionState() error = %v", err) + } + if len(loaded.FilesTouched) != 1 || loaded.FilesTouched[0] != "file.go" { + t.Errorf("FilesTouched = %v, want [file.go]", loaded.FilesTouched) + } +} + +func TestRecordFilesTouched_NotFound_NoOp(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + // Should not error when session doesn't exist. + err = RecordFilesTouched(context.Background(), "nonexistent", + []string{"file.go"}, nil, nil) + if err != nil { + t.Fatalf("RecordFilesTouched() for nonexistent session error = %v, want nil", err) + } +} + +func TestRecordFilesTouched_Deduplicates(t *testing.T) { + dir := t.TempDir() + _, err := git.PlainInit(dir, false) + if err != nil { + t.Fatalf("failed to init git repo: %v", err) + } + t.Chdir(dir) + + ctx := context.Background() + state := &SessionState{ + SessionID: "rft-dedup", + BaseCommit: "abc123", + StartedAt: time.Now(), + FilesTouched: []string{"file.go"}, + } + if err := SaveSessionState(ctx, state); err != nil { + t.Fatalf("SaveSessionState() error = %v", err) + } + + // Same file in modified list should not create a duplicate. 
+ err = RecordFilesTouched(ctx, "rft-dedup", + []string{"file.go"}, nil, nil) + if err != nil { + t.Fatalf("RecordFilesTouched() error = %v", err) + } + + loaded, err := LoadSessionState(ctx, "rft-dedup") + if err != nil { + t.Fatalf("LoadSessionState() error = %v", err) + } + count := 0 + for _, f := range loaded.FilesTouched { + if f == "file.go" { + count++ + } + } + if count != 1 { + t.Errorf("file.go appears %d times, want 1 in %v", count, loaded.FilesTouched) + } +} + +// --- goroutineID test --- + +func TestGoroutineID_ReturnsPositive(t *testing.T) { + t.Parallel() + id := goroutineID() + if id <= 0 { + t.Errorf("goroutineID() = %d, want > 0", id) + } +} + func TestClearSessionState_RemovesHintFile(t *testing.T) { dir := t.TempDir() _, err := git.PlainInit(dir, false) diff --git a/cmd/trace/cli/strategy/state_lock_unix.go b/cmd/trace/cli/strategy/state_lock_unix.go deleted file mode 100644 index 955c99e..0000000 --- a/cmd/trace/cli/strategy/state_lock_unix.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build unix - -package strategy - -import ( - "fmt" - "os" - "syscall" -) - -// acquireStateFileLock takes an exclusive POSIX advisory lock on path. The -// returned release closes the file (which drops the flock). Callers must call -// release exactly once. The lock file persists between runs — that's fine, -// flock state is held by the file descriptor, not the inode on disk. 
-// -//nolint:unused // platform-specific implementation; will be called once state locking is wired in -func acquireStateFileLock(path string) (release func(), err error) { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) //nolint:gosec // path built from validated session ID - if err != nil { - return nil, fmt.Errorf("open state lock: %w", err) - } - if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil { //nolint:gosec // file descriptors are non-negative; standard Go pattern for syscall.Flock - _ = f.Close() - return nil, fmt.Errorf("flock state lock: %w", err) - } - return func() { _ = f.Close() }, nil -} diff --git a/cmd/trace/cli/trail_cmd.go b/cmd/trace/cli/trail_cmd.go index 633f482..6bf4715 100644 --- a/cmd/trace/cli/trail_cmd.go +++ b/cmd/trace/cli/trail_cmd.go @@ -10,12 +10,12 @@ import ( "os/exec" "sort" "strings" + "text/tabwriter" "time" "github.com/GrayCodeAI/trace/cmd/trace/cli/api" "github.com/GrayCodeAI/trace/cmd/trace/cli/gitremote" "github.com/GrayCodeAI/trace/cmd/trace/cli/strategy" - "github.com/GrayCodeAI/trace/cmd/trace/cli/stringutil" "github.com/GrayCodeAI/trace/cmd/trace/cli/trail" "charm.land/huh/v2" @@ -24,6 +24,33 @@ import ( "github.com/spf13/cobra" ) +const ( + defaultTrailListLimit = 10 + trailListAuthorMe = "me" + defaultTrailListStatus = string(trail.StatusInProgress) + // trailListStatusAny disables the status filter; user-facing value for --status. + trailListStatusAny = "any" +) + +// trailListOptions are the inputs to runTrailListAll. Keeping them on a +// struct avoids a long positional argument list at the two call sites. 
+type trailListOptions struct { + Author string + Status string + JSON bool + Limit int + InsecureHTTP bool + ShowAll bool +} + +func defaultTrailListOptions(insecureHTTP bool) trailListOptions { + return trailListOptions{ + Status: defaultTrailListStatus, + Limit: defaultTrailListLimit, + InsecureHTTP: insecureHTTP, + } +} + func newTrailCmd() *cobra.Command { var insecureHTTPAuth bool @@ -65,7 +92,7 @@ func trailInsecureHTTP(cmd *cobra.Command) bool { func runTrailShow(ctx context.Context, w io.Writer, insecureHTTP bool) error { branch, err := GetCurrentBranch(ctx) if err != nil { - return runTrailListAll(ctx, w, "", false, false, insecureHTTP) + return runTrailListAll(ctx, w, defaultTrailListOptions(insecureHTTP)) } client, err := NewAuthenticatedAPIClient(insecureHTTP) @@ -83,7 +110,7 @@ func runTrailShow(ctx context.Context, w io.Writer, insecureHTTP bool) error { return err } if found == nil { - return runTrailListAll(ctx, w, "", false, false, insecureHTTP) + return runTrailListAll(ctx, w, defaultTrailListOptions(insecureHTTP)) } printTrailDetails(w, found.ToMetadata()) @@ -92,7 +119,12 @@ func runTrailShow(ctx context.Context, w io.Writer, insecureHTTP bool) error { func printTrailDetails(w io.Writer, m *trail.Metadata) { fmt.Fprintf(w, "Trail: %s\n", m.Title) - fmt.Fprintf(w, " ID: %s\n", m.TrailID) + if m.Number > 0 { + fmt.Fprintf(w, " Number: %d\n", m.Number) + } + if !m.TrailID.IsEmpty() { + fmt.Fprintf(w, " ID: %s\n", m.TrailID) + } fmt.Fprintf(w, " Branch: %s\n", m.Branch) fmt.Fprintf(w, " Base: %s\n", m.Base) fmt.Fprintf(w, " Status: %s\n", m.Status) @@ -111,27 +143,44 @@ func printTrailDetails(w io.Writer, m *trail.Metadata) { } func newTrailListCmd() *cobra.Command { - var statusFilter string - var jsonOutput bool - var showAll bool + var opts trailListOptions cmd := &cobra.Command{ Use: "list", - Short: "List all trails", + Short: "List recent trails", RunE: func(cmd *cobra.Command, _ []string) error { - return runTrailListAll(cmd.Context(), 
cmd.OutOrStdout(), statusFilter, jsonOutput, showAll, trailInsecureHTTP(cmd)) + opts.InsecureHTTP = trailInsecureHTTP(cmd) + return runTrailListAll(cmd.Context(), cmd.OutOrStdout(), opts) }, } - cmd.Flags().StringVar(&statusFilter, "status", "", "Filter by status (draft, open, in_progress, in_review, merged, closed)") - cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output as JSON") - cmd.Flags().BoolVarP(&showAll, "all", "a", false, "Include merged and closed trails") + cmd.Flags().StringVar(&opts.Author, "author", "", + "Filter by author login (case-insensitive); use '"+trailListAuthorMe+"' for yourself (requires gh CLI); omit for any author") + cmd.Flags().StringVar(&opts.Status, "status", defaultTrailListStatus, + "Filter by comma-separated status(es): "+formatValidStatuses()+"; use '"+trailListStatusAny+"' for all statuses") + cmd.Flags().BoolVar(&opts.JSON, "json", false, "Output as JSON (respects --status filter)") + cmd.Flags().IntVarP(&opts.Limit, "limit", "n", defaultTrailListLimit, "Maximum number of trails to show") + cmd.Flags().BoolVarP(&opts.ShowAll, "all", "a", false, "Show all statuses (equivalent to --status any)") return cmd } -func runTrailListAll(ctx context.Context, w io.Writer, statusFilter string, jsonOutput, showAll, insecureHTTP bool) error { - client, err := NewAuthenticatedAPIClient(insecureHTTP) +func runTrailListAll(ctx context.Context, w io.Writer, opts trailListOptions) error { + if opts.Limit <= 0 { + return errors.New("limit must be greater than 0") + } + + // When --all is set, treat it as --status any. 
+ statusSource := opts.Status + if opts.ShowAll { + statusSource = trailListStatusAny + } + statusFilters, err := parseTrailStatusFilter(statusSource) + if err != nil { + return err + } + + client, err := NewAuthenticatedAPIClient(opts.InsecureHTTP) if err != nil { return fmt.Errorf("authentication required: %w", err) } @@ -161,38 +210,32 @@ func runTrailListAll(ctx context.Context, w io.Writer, statusFilter string, json trails = append(trails, listResp.Trails[i].ToMetadata()) } - totalCount := len(trails) - - // Apply status filter - if statusFilter != "" { - status := trail.Status(statusFilter) - if !status.IsValid() { - return fmt.Errorf("invalid status %q: valid values are %s", statusFilter, formatValidStatuses()) - } - var filtered []*trail.Metadata - for _, t := range trails { - if t.Status == status { - filtered = append(filtered, t) - } + authorFilter := opts.Author + currentUserLogin := "" + if authorFilter == trailListAuthorMe { + login, loginErr := fetchCurrentUserLogin(ctx, execRunner{}) + if loginErr != nil { + return loginErr } - trails = filtered - } else if !showAll { - // By default, hide merged and closed trails - var filtered []*trail.Metadata - for _, t := range trails { - if t.Status != trail.StatusMerged && t.Status != trail.StatusClosed { - filtered = append(filtered, t) - } - } - trails = filtered + currentUserLogin = login + authorFilter = login + } + + if authorFilter != "" { + trails = filterTrailsByAuthor(trails, authorFilter) } - // Sort by updated_at descending + if len(statusFilters) > 0 { + trails = filterTrailsByStatuses(trails, statusFilters) + } + + // Sort by updated_at descending, then keep only the most recent rows. 
sort.Slice(trails, func(i, j int) bool { return trails[i].UpdatedAt.After(trails[j].UpdatedAt) }) + trails = limitTrails(trails, opts.Limit) - if jsonOutput { + if opts.JSON { enc := json.NewEncoder(w) enc.SetIndent("", " ") if err := enc.Encode(trails); err != nil { @@ -202,30 +245,257 @@ func runTrailListAll(ctx context.Context, w io.Writer, statusFilter string, json } if len(trails) == 0 { - hiddenCount := totalCount - len(trails) - if hiddenCount > 0 { - fmt.Fprintf(w, "No active trails found. %d merged/closed trail(s) hidden — use --all to show.\n", hiddenCount) - } else { - fmt.Fprintln(w, "No trails found.") + fmt.Fprintln(w, "No trails found.") + fmt.Fprintln(w) + fmt.Fprintln(w, "Commands:") + fmt.Fprintln(w, " trace trail create Create a trail for the current branch") + fmt.Fprintln(w, " trace trail list List recent trails") + fmt.Fprintln(w, " trace trail update Update trail metadata") + return nil + } + + printTrailList(w, trails, trailListDisplayOptions{ + RequestedAuthor: authorFilter, + CurrentUser: currentUserLogin, + StatusFilters: statusFilters, + }) + + return nil +} + +func limitTrails(trails []*trail.Metadata, limit int) []*trail.Metadata { + if len(trails) <= limit { + return trails + } + return trails[:limit] +} + +// filterTrailsByAuthor matches case-insensitively because GitHub logins are +// case-insensitive (e.g. "Alice" and "alice" identify the same user). 
+func filterTrailsByAuthor(trails []*trail.Metadata, login string) []*trail.Metadata { + var filtered []*trail.Metadata + for _, t := range trails { + if strings.EqualFold(t.AuthorLogin(), login) { + filtered = append(filtered, t) + } + } + return filtered +} + +func filterTrailsByStatuses(trails []*trail.Metadata, statuses []trail.Status) []*trail.Metadata { + statusSet := make(map[trail.Status]bool, len(statuses)) + for _, status := range statuses { + statusSet[status] = true + } + + var filtered []*trail.Metadata + for _, t := range trails { + if statusSet[t.Status] { + filtered = append(filtered, t) + } + } + return filtered +} + +func parseTrailStatusFilter(filter string) ([]trail.Status, error) { + if filter == "" || filter == trailListStatusAny { + return nil, nil + } + + parts := strings.Split(filter, ",") + statuses := make([]trail.Status, 0, len(parts)) + seen := make(map[trail.Status]bool, len(parts)) + for _, part := range parts { + name := strings.TrimSpace(part) + if name == "" { + return nil, fmt.Errorf("invalid status filter %q: empty status", filter) + } + status := trail.Status(name) + if !status.IsValid() { + return nil, fmt.Errorf("invalid status %q: valid values are %s", name, formatValidStatuses()) + } + if seen[status] { + continue + } + seen[status] = true + statuses = append(statuses, status) + } + return statuses, nil +} + +// fetchCurrentUserLogin resolves --author me to a GitHub login via the local +// gh CLI. The runner is injectable so tests can stub gh without touching the +// process environment. 
+func fetchCurrentUserLogin(ctx context.Context, runner bootstrapRunner) (string, error) { + login, err := ghCurrentUser(ctx, runner) + if err != nil { + return "", fmt.Errorf("resolve --author %s via gh CLI: %w\nhint: pass --author explicitly if gh is unavailable", trailListAuthorMe, err) + } + if login == "" { + return "", errors.New("resolve --author me: gh returned an empty login") + } + return login, nil +} + +type trailListDisplayOptions struct { + RequestedAuthor string + CurrentUser string + StatusFilters []trail.Status +} + +func printTrailList(w io.Writer, trails []*trail.Metadata, opts trailListDisplayOptions) { + showAuthor := opts.RequestedAuthor == "" + // Group by status when the user filtered for 0 or 2+ statuses. A single + // status is already named in the header, so flat rows read more cleanly. + grouped := len(opts.StatusFilters) != 1 + printTrailListHeader(w, opts, len(trails)) + fmt.Fprintln(w) + if !grouped { + printTrailRows(w, trails, showAuthor) + return + } + + rendered := make(map[*trail.Metadata]bool, len(trails)) + for _, status := range trailListStatusOrder(opts.StatusFilters) { + group := filterTrailsByStatus(trails, status) + if len(group) == 0 { + continue + } + for _, t := range group { + rendered[t] = true + } + fmt.Fprintf(w, " %s · %d\n", trailStatusTitle(status), len(group)) + fmt.Fprintln(w) + printTrailRows(w, group, showAuthor) + fmt.Fprintln(w) + } + + // When no explicit status filter is set, surface trails with unknown + // statuses in an "Other" bucket so they don't silently disappear if the + // server adds a status the CLI hasn't learned about yet. 
+ if len(opts.StatusFilters) == 0 { + var other []*trail.Metadata + for _, t := range trails { + if !rendered[t] { + other = append(other, t) + } + } + if len(other) > 0 { + fmt.Fprintf(w, " Other · %d\n", len(other)) + fmt.Fprintln(w) + printTrailRows(w, other, showAuthor) fmt.Fprintln(w) - fmt.Fprintln(w, "Commands:") - fmt.Fprintln(w, " trace trail create Create a trail for the current branch") - fmt.Fprintln(w, " trace trail list List all trails") - fmt.Fprintln(w, " trace trail update Update trail metadata") } - return nil + } +} + +func printTrailListHeader(w io.Writer, opts trailListDisplayOptions, count int) { + if opts.RequestedAuthor == "" { + if len(opts.StatusFilters) == 0 { + fmt.Fprintf(w, " Recent %s · %d\n", pluralize("trail", count), count) + return + } + fmt.Fprintf(w, " %s · %d %s\n", trailStatusListTitle(opts.StatusFilters), count, pluralize("trail", count)) + return } - // Table output - fmt.Fprintf(w, "%-30s %-40s %-13s %-15s %s\n", "BRANCH", "TITLE", "STATUS", "AUTHOR", "UPDATED") + label := opts.RequestedAuthor + // When --author me resolves to the same login the server already returned + // for the trail, render "Your trails (login)" so identity drift between + // gh and Trace is visible at a glance. + if opts.CurrentUser != "" && strings.EqualFold(opts.RequestedAuthor, opts.CurrentUser) { + label = fmt.Sprintf("Your trails (%s)", opts.CurrentUser) + } + if len(opts.StatusFilters) == 0 { + fmt.Fprintf(w, " %s · %d\n", label, count) + return + } + fmt.Fprintf(w, " %s · %d %s\n", label, count, trailStatusListDisplay(opts.StatusFilters)) +} + +func printTrailRows(w io.Writer, trails []*trail.Metadata, showAuthor bool) { + // tabwriter aligns by display columns instead of bytes, so multi-byte + // branch names or logins don't throw off the table. 
+ tw := tabwriter.NewWriter(w, 0, 0, 2, ' ', 0) for _, t := range trails { - branch := stringutil.TruncateRunes(t.Branch, 30, "...") - title := stringutil.TruncateRunes(t.Title, 40, "...") - fmt.Fprintf(w, "%-30s %-40s %-13s %-15s %s\n", - branch, title, t.Status, stringutil.TruncateRunes(t.AuthorLogin(), 15, "..."), timeAgo(t.UpdatedAt)) + if showAuthor { + fmt.Fprintf(tw, " %s\t%s\t%s\n", t.Branch, t.AuthorLogin(), timeAgo(t.UpdatedAt)) + continue + } + fmt.Fprintf(tw, " %s\t%s\n", t.Branch, timeAgo(t.UpdatedAt)) } + _ = tw.Flush() +} - return nil +func filterTrailsByStatus(trails []*trail.Metadata, status trail.Status) []*trail.Metadata { + var filtered []*trail.Metadata + for _, t := range trails { + if t.Status == status { + filtered = append(filtered, t) + } + } + return filtered +} + +func trailListStatusOrder(filter []trail.Status) []trail.Status { + order := []trail.Status{ + trail.StatusInProgress, + trail.StatusOpen, + trail.StatusInReview, + trail.StatusDraft, + trail.StatusMerged, + trail.StatusClosed, + } + if len(filter) == 0 { + return order + } + + allowed := make(map[trail.Status]bool, len(filter)) + for _, status := range filter { + allowed[status] = true + } + var filtered []trail.Status + for _, status := range order { + if allowed[status] { + filtered = append(filtered, status) + } + } + return filtered +} + +func trailStatusListDisplay(statuses []trail.Status) string { + parts := make([]string, len(statuses)) + for i, status := range statuses { + parts[i] = trailStatusDisplay(status) + } + return strings.Join(parts, ", ") +} + +func trailStatusListTitle(statuses []trail.Status) string { + display := trailStatusListDisplay(statuses) + if display == "" { + return "" + } + return strings.ToUpper(display[:1]) + display[1:] +} + +func trailStatusDisplay(status trail.Status) string { + return strings.ReplaceAll(string(status), "_", " ") +} + +func trailStatusTitle(status trail.Status) string { + display := trailStatusDisplay(status) + if display == 
"" { + return "" + } + return strings.ToUpper(display[:1]) + display[1:] +} + +func pluralize(s string, count int) string { + if count == 1 { + return s + } + return s + "s" } func newTrailCreateCmd() *cobra.Command { @@ -350,7 +620,7 @@ func runTrailCreate(cmd *cobra.Command, title, body, base, branch, statusStr str return fmt.Errorf("failed to decode create response: %w", err) } - fmt.Fprintf(w, "Created trail %q for branch %s (ID: %s)\n", createResp.Trail.Title, createResp.Trail.Branch, createResp.Trail.TrailID) + fmt.Fprintf(w, "Created trail %q for branch %s (ID: %s)\n", createResp.Trail.Title, createResp.Trail.Branch, createResp.Trail.ID) // --- Phase 3: Post-creation local operations --- @@ -482,7 +752,7 @@ func runTrailUpdate(ctx context.Context, w, errW io.Writer, insecureHTTP bool, s // Build update request with only changed fields updateReq := buildTrailUpdateRequest(found, statusStr, title, body, labelAdd, labelRemove) - resp, err := client.Patch(ctx, trailsBasePath(host, owner, repoName)+"/"+found.TrailID, updateReq) + resp, err := client.Patch(ctx, trailsBasePath(host, owner, repoName)+"/"+found.ID, updateReq) if err != nil { return fmt.Errorf("failed to update trail: %w", err) } @@ -547,6 +817,12 @@ func buildTrailUpdateRequest(current *api.TrailResource, statusStr, title, body // defaultBaseBranch is the fallback base branch name when it cannot be determined. const defaultBaseBranch = "main" +// masterBaseBranch is the secondary fallback for repos still using "master" +// (pre-git-2.28 defaults, forks of older projects, etc.). Extracted as a +// constant so goconst stays quiet across the several call sites in the cli +// package. +const masterBaseBranch = "master" + func formatValidStatuses() string { statuses := trail.ValidStatuses() names := make([]string, len(statuses)) @@ -619,6 +895,19 @@ func runTrailCreateInteractive(title, body, branch, statusStr *string) error { // findTrailByBranch looks up a trail by branch name via the list API. 
func findTrailByBranch(ctx context.Context, client *api.Client, host, owner, repo, branch string) (*api.TrailResource, error) { + return findTrail(ctx, client, host, owner, repo, func(t api.TrailResource) bool { + return t.Branch == branch + }) +} + +// findTrailByNumber looks up a trail by numeric identifier via the list API. +func findTrailByNumber(ctx context.Context, client *api.Client, host, owner, repo string, number int) (*api.TrailResource, error) { //nolint:unused // used by trail_watch when number is provided directly + return findTrail(ctx, client, host, owner, repo, func(t api.TrailResource) bool { + return t.Number == number + }) +} + +func findTrail(ctx context.Context, client *api.Client, host, owner, repo string, match func(api.TrailResource) bool) (*api.TrailResource, error) { resp, err := client.Get(ctx, trailsBasePath(host, owner, repo)) if err != nil { return nil, fmt.Errorf("list trails: %w", err) @@ -634,7 +923,7 @@ func findTrailByBranch(ctx context.Context, client *api.Client, host, owner, rep } for i := range listResp.Trails { - if listResp.Trails[i].Branch == branch { + if match(listResp.Trails[i]) { return &listResp.Trails[i], nil } } diff --git a/cmd/trace/cli/trail_cmd_test.go b/cmd/trace/cli/trail_cmd_test.go new file mode 100644 index 0000000..ae34bac --- /dev/null +++ b/cmd/trace/cli/trail_cmd_test.go @@ -0,0 +1,287 @@ +package cli + +import ( + "bytes" + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/GrayCodeAI/trace/cmd/trace/cli/trail" +) + +const ( + trailListTestAuthorAlice = "alice" + trailListTestAuthorBob = "bob" +) + +func TestLimitTrailsKeepsMostRecentPrefix(t *testing.T) { + t.Parallel() + trails := []*trail.Metadata{ + {Branch: "newest"}, + {Branch: "middle"}, + {Branch: "oldest"}, + } + + got := limitTrails(trails, 2) + if len(got) != 2 { + t.Fatalf("len = %d, want 2", len(got)) + } + if got[0].Branch != "newest" || got[1].Branch != "middle" { + t.Fatalf("got branches %q, %q; want newest, middle", 
got[0].Branch, got[1].Branch) + } + + if all := limitTrails(trails, 3); len(all) != len(trails) { + t.Fatalf("limit 3 len = %d, want %d", len(all), len(trails)) + } +} + +func TestFilterTrailsByAuthor(t *testing.T) { + t.Parallel() + alice := trailListTestAuthorAlice + bob := trailListTestAuthorBob + trails := []*trail.Metadata{ + {Branch: "mine-1", Author: &trail.Author{Login: &alice}}, + {Branch: "theirs", Author: &trail.Author{Login: &bob}}, + {Branch: "unknown"}, + {Branch: "mine-2", Author: &trail.Author{Login: &alice}}, + } + + got := filterTrailsByAuthor(trails, trailListTestAuthorAlice) + if len(got) != 2 { + t.Fatalf("len = %d, want 2", len(got)) + } + if got[0].Branch != "mine-1" || got[1].Branch != "mine-2" { + t.Fatalf("got branches %q, %q; want mine-1, mine-2", got[0].Branch, got[1].Branch) + } +} + +func TestFilterTrailsByAuthorIsCaseInsensitive(t *testing.T) { + t.Parallel() + mixed := "Alice" + trails := []*trail.Metadata{ + {Branch: "mine", Author: &trail.Author{Login: &mixed}}, + } + + got := filterTrailsByAuthor(trails, "alice") + if len(got) != 1 { + t.Fatalf("len = %d, want 1 (case-insensitive)", len(got)) + } +} + +func TestParseTrailStatusFilterAcceptsCommaSeparatedStatuses(t *testing.T) { + t.Parallel() + got, err := parseTrailStatusFilter("in_progress, open,closed") + if err != nil { + t.Fatalf("parseTrailStatusFilter: %v", err) + } + want := []trail.Status{trail.StatusInProgress, trail.StatusOpen, trail.StatusClosed} + if len(got) != len(want) { + t.Fatalf("len = %d, want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("status[%d] = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestParseTrailStatusFilterRejectsInvalidStatus(t *testing.T) { + t.Parallel() + if _, err := parseTrailStatusFilter("in_progress,nope"); err == nil { + t.Fatal("expected invalid status error") + } +} + +func TestParseTrailStatusFilterAnySentinelMeansNoFilter(t *testing.T) { + t.Parallel() + got, err := 
parseTrailStatusFilter(trailListStatusAny) + if err != nil { + t.Fatalf("parseTrailStatusFilter(%q): %v", trailListStatusAny, err) + } + if got != nil { + t.Fatalf("got %v, want nil (any disables the filter)", got) + } +} + +func TestPrintTrailListDefaultRepoShapeShowsAuthor(t *testing.T) { + t.Parallel() + alice := trailListTestAuthorAlice + var out bytes.Buffer + printTrailList(&out, []*trail.Metadata{ + { + Branch: "feat/repo-wide", + Status: trail.StatusInProgress, + Author: &trail.Author{Login: &alice}, + UpdatedAt: time.Now(), + }, + }, trailListDisplayOptions{ + RequestedAuthor: "", + StatusFilters: []trail.Status{trail.StatusInProgress}, + }) + + text := out.String() + for _, want := range []string{"In progress · 1 trail", "feat/repo-wide", trailListTestAuthorAlice} { + if !strings.Contains(text, want) { + t.Fatalf("output missing %q, got:\n%s", want, text) + } + } +} + +func TestPrintTrailListAuthorFilteredShapeHidesAuthor(t *testing.T) { + t.Parallel() + longBranch := "feature/very-long-branch-name-that-must-remain-visible" + alice := trailListTestAuthorAlice + + var out bytes.Buffer + printTrailList(&out, []*trail.Metadata{ + { + Branch: longBranch, + Status: trail.StatusInProgress, + Author: &trail.Author{Login: &alice}, + UpdatedAt: time.Now().Add(-24 * time.Hour), + }, + }, trailListDisplayOptions{ + RequestedAuthor: trailListTestAuthorAlice, + StatusFilters: []trail.Status{trail.StatusInProgress}, + }) + + text := out.String() + if !strings.Contains(text, "alice · 1 in progress") { + t.Fatalf("output should contain author/status header, got:\n%s", text) + } + if !strings.Contains(text, longBranch) { + t.Fatalf("output should contain full branch %q, got:\n%s", longBranch, text) + } + if strings.Count(text, "alice") != 1 { + t.Fatalf("filtered author output should not repeat author in rows, got:\n%s", text) + } +} + +func TestPrintTrailListYourTrailsRelabelsAndSurfacesGhLogin(t *testing.T) { + t.Parallel() + mixedCase := "Alice" // gh returned a 
different case than the filter + var out bytes.Buffer + printTrailList(&out, []*trail.Metadata{ + { + Branch: "feat/x", + Status: trail.StatusInProgress, + Author: &trail.Author{Login: &mixedCase}, + UpdatedAt: time.Now(), + }, + }, trailListDisplayOptions{ + RequestedAuthor: "alice", + CurrentUser: "alice", + StatusFilters: []trail.Status{trail.StatusInProgress}, + }) + + text := out.String() + if !strings.Contains(text, "Your trails (alice) · 1 in progress") { + t.Fatalf("expected 'Your trails (alice)' header, got:\n%s", text) + } +} + +func TestPrintTrailListAnyAuthorAnyStatusGroupsByStatus(t *testing.T) { + t.Parallel() + alice := trailListTestAuthorAlice + bob := trailListTestAuthorBob + var out bytes.Buffer + printTrailList(&out, []*trail.Metadata{ + {Branch: "feat/a", Status: trail.StatusInProgress, Author: &trail.Author{Login: &alice}, UpdatedAt: time.Now()}, + {Branch: "fix/b", Status: trail.StatusOpen, Author: &trail.Author{Login: &bob}, UpdatedAt: time.Now()}, + }, trailListDisplayOptions{ + RequestedAuthor: "", + StatusFilters: nil, + }) + + text := out.String() + if strings.Index(text, "In progress · 1") > strings.Index(text, "Open · 1") { + t.Fatalf("expected in-progress group before open group, got:\n%s", text) + } + for _, want := range []string{"Recent trails · 2", "In progress · 1", "Open · 1", "feat/a", trailListTestAuthorAlice, "fix/b", trailListTestAuthorBob} { + if !strings.Contains(text, want) { + t.Fatalf("output missing %q, got:\n%s", want, text) + } + } +} + +func TestPrintTrailListSingularRecentTrailWhenOne(t *testing.T) { + t.Parallel() + alice := trailListTestAuthorAlice + var out bytes.Buffer + printTrailList(&out, []*trail.Metadata{ + {Branch: "feat/a", Status: trail.StatusInProgress, Author: &trail.Author{Login: &alice}, UpdatedAt: time.Now()}, + }, trailListDisplayOptions{ + RequestedAuthor: "", + StatusFilters: nil, + }) + + text := out.String() + if !strings.Contains(text, "Recent trail · 1") { + t.Fatalf("expected singular 
'Recent trail · 1', got:\n%s", text) + } + if strings.Contains(text, "Recent trails · 1") { + t.Fatalf("did not expect plural 'trails' for count 1, got:\n%s", text) + } +} + +func TestPrintTrailListUnknownStatusGroupedInOtherBucket(t *testing.T) { + t.Parallel() + alice := trailListTestAuthorAlice + unknownStatus := trail.Status("experimental_review") + var out bytes.Buffer + printTrailList(&out, []*trail.Metadata{ + {Branch: "feat/known", Status: trail.StatusInProgress, Author: &trail.Author{Login: &alice}, UpdatedAt: time.Now()}, + {Branch: "feat/odd", Status: unknownStatus, Author: &trail.Author{Login: &alice}, UpdatedAt: time.Now()}, + }, trailListDisplayOptions{ + RequestedAuthor: "", + StatusFilters: nil, + }) + + text := out.String() + for _, want := range []string{"Recent trails · 2", "In progress · 1", "Other · 1", "feat/odd"} { + if !strings.Contains(text, want) { + t.Fatalf("output missing %q, got:\n%s", want, text) + } + } +} + +func TestFetchCurrentUserLoginReturnsLogin(t *testing.T) { + t.Parallel() + r := newFakeRunner() + r.set("gh", []string{"api", "user", "--jq", ".login"}, "octocat\n", nil) + + got, err := fetchCurrentUserLogin(context.Background(), r) + if err != nil { + t.Fatalf("fetchCurrentUserLogin: %v", err) + } + if got != "octocat" { + t.Fatalf("got %q, want octocat", got) + } +} + +func TestFetchCurrentUserLoginRejectsEmptyLogin(t *testing.T) { + t.Parallel() + r := newFakeRunner() + r.set("gh", []string{"api", "user", "--jq", ".login"}, "\n", nil) + + if _, err := fetchCurrentUserLogin(context.Background(), r); err == nil { + t.Fatal("expected error for empty login") + } +} + +func TestFetchCurrentUserLoginWrapsGhError(t *testing.T) { + t.Parallel() + r := newFakeRunner() + r.set("gh", []string{"api", "user", "--jq", ".login"}, "", errors.New("gh: not authenticated")) + + _, err := fetchCurrentUserLogin(context.Background(), r) + if err == nil { + t.Fatal("expected error") + } + // Surface the hint about the --author fallback. 
+ if !strings.Contains(err.Error(), "--author ") { + t.Fatalf("error should mention the --author fallback hint, got: %v", err) + } +} diff --git a/redact/custom.go b/redact/custom.go new file mode 100644 index 0000000..a67f1d8 --- /dev/null +++ b/redact/custom.go @@ -0,0 +1,141 @@ +package redact + +import ( + "log/slog" + "regexp" + "sync" +) + +// CustomRulesConfig configures inline custom_redactions and parsed rule packs. +type CustomRulesConfig struct { + // Inline maps a label (used only in logs/diagnostics) to a Go RE2 regex + // string. Failed compilations are logged via slog.Warn and dropped. + Inline map[string]string + + // Packs are pre-parsed rule packs (see LoadPacks). Per-rule regex + // compilation failures are logged and dropped; sample mismatches are + // logged but do not drop the rule. + Packs []*Pack +} + +// compiledCustomRule is a compiled regex retained across calls. +// label is unused for replacement (custom rules always emit the bare REDACTED +// token to match other secret layers), but is preserved for diagnostics. +type compiledCustomRule struct { + label string + regex *regexp.Regexp +} + +type customRulesState struct { + rules []compiledCustomRule +} + +var ( + customConfig *customRulesState + customConfigMu sync.RWMutex +) + +// componentAttr tags every warning emitted by this package so log aggregators +// can filter redaction failures with the same key the CLI uses elsewhere +// (`logging.WithComponent(ctx, "redaction")`). +var componentAttr = slog.String("component", "redaction") + +// ConfigureCustomRules compiles user-defined redaction rules and stores the +// result for use by redact.String(). Sample-validation runs here too, so +// failures surface the next time any process initializes redaction. +// +// Call once at process startup after loading settings. Thread-safe. 
+func ConfigureCustomRules(cfg CustomRulesConfig) { + state := &customRulesState{} + + for label, pattern := range cfg.Inline { + compiled, ok := compileCustomRule( + label, + pattern, + "skipping invalid custom_redactions pattern", + slog.String("label", label), + ) + if ok { + state.rules = append(state.rules, compiled) + } + } + + for _, pack := range cfg.Packs { + for _, rule := range pack.Rules { + compiled, ok := compileCustomRule( + pack.Name+"."+rule.ID, + rule.Regex, + "skipping invalid pack rule", + slog.String("pack", pack.sourcePath), + slog.String("rule", rule.ID), + ) + if ok { + state.rules = append(state.rules, compiled) + runRuleSamples(pack, rule, compiled.regex) + } + } + } + + customConfigMu.Lock() + defer customConfigMu.Unlock() + customConfig = state +} + +func compileCustomRule(label, pattern, warning string, attrs ...any) (compiledCustomRule, bool) { + compiled, err := regexp.Compile(pattern) + if err != nil { + all := make([]any, 0, len(attrs)+2) + all = append(all, componentAttr) + all = append(all, attrs...) + all = append(all, slog.String("error", err.Error())) + slog.Warn(warning, all...) + return compiledCustomRule{}, false + } + return compiledCustomRule{label: label, regex: compiled}, true +} + +// runRuleSamples checks each sample against the compiled regex and logs a +// warning per mismatch. Failures never drop the rule — sample validation +// is informational, not gating. +func runRuleSamples(pack *Pack, rule Rule, compiled *regexp.Regexp) { + for i, s := range rule.Samples { + got := compiled.MatchString(s.Input) + if got != s.Redacted { + slog.Warn("redactor pack sample mismatch", + componentAttr, + slog.String("pack", pack.sourcePath), + slog.String("rule", rule.ID), + slog.Int("sample_index", i), + slog.Int("sample_length", len(s.Input)), + slog.Bool("expected", s.Redacted), + slog.Bool("got", got)) + } + } +} + +// getCustomRulesConfig returns the currently-configured custom rules. 
+// Returns nil if ConfigureCustomRules has never been called. +func getCustomRulesConfig() *customRulesState { + customConfigMu.RLock() + defer customConfigMu.RUnlock() + return customConfig +} + +// detectCustomRules returns tagged regions for every match of every +// configured custom rule. Returns nil if no rules are configured. +// +// All regions use an empty label so they are replaced with the bare +// "REDACTED" token used by the built-in secret layers, not the +// "[REDACTED_